Skip to content

Commit

Permalink
fno layer
Browse files Browse the repository at this point in the history
  • Loading branch information
bnb32 committed Nov 17, 2023
1 parent 42c914a commit 452af90
Showing 1 changed file with 98 additions and 3 deletions.
101 changes: 98 additions & 3 deletions phygnn/layers/custom_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,8 @@ def __init__(self, name):
Unique string identifier of the skip connection. The skip endpoint
should have the same name.
"""
super().__init__(name=name)
super().__init__()
self._name = name
self._cache = None

def call(self, x):
Expand Down Expand Up @@ -597,6 +598,100 @@ def call(self, x):
return x


class FourierNeuralOperator(tf.keras.layers.Layer):
"""Custom layer for fourier neural operator block
Note that this is only set up to take a channels-last input
References
----------
1. FourCastNet: A Global Data-driven High-resolution Weather Model using
Adaptive Fourier Neural Operators. http://arxiv.org/abs/2202.11214
"""

def __init__(self, ratio=16, sparsity_threshold=0.5):
"""
Parameters
----------
ratio : int
Number of channels/filters divided by the number of
dense connections in the FNO block.
sparsity_threshold : float
Parameter to control sparsity and shrinkage in the softshrink
activation function.
"""

super().__init__()
self._ratio = ratio
self.fft_layer = None
self.ifft_layer = None
self.mlp_layers = None
self.sparsity_threshold = sparsity_threshold

def softshrink(self, x, lambd=0.5):
"""Softshrink activation function
https://pytorch.org/docs/stable/generated/torch.nn.Softshrink.html
"""
x = tf.convert_to_tensor(x)
values_below_lower = tf.where(x < -lambd, x + lambd, 0)
values_above_upper = tf.where(lambd < x, x - lambd, 0)
return values_below_lower + values_above_upper

def build(self, input_shape):
"""Build the FNO layer based on an input shape
Parameters
----------
input_shape : tuple
Shape tuple of the input tensor
"""

self._n_channels = input_shape[-1]
self._dense_units = int(np.ceil(self._n_channels / self._ratio))

if len(input_shape) == 4:
self.fft_layer = tf.signal.fft2d
self.ifft_layer = tf.signal.ifft2d
elif len(input_shape) == 5:
self.fft_layer = tf.signal.fft3d
self.ifft_layer = tf.signal.ifft3d
else:
msg = ('FourierNeuralOperator layer can only accept 4D or 5D data '
'for image or video input but received input shape: {}'
.format(input_shape))
logger.error(msg)
raise RuntimeError(msg)

self.mlp_layers = [
tf.keras.layers.Dense(self._dense_units, activation='relu'),
tf.keras.layers.Dense(self._n_channels)]

def call(self, x):
"""Call the custom FourierNeuralOperator layer
Parameters
----------
x : tf.Tensor
Input tensor.
Returns
-------
x : tf.Tensor
Output tensor, this is the FNO weights added to the original input
tensor.
"""

t_in = x
x = self.fft_layer(x)
for layer in self.mlp_layers:
x = layer(x)
x = self.softshrink(x, lambd=self.sparsity_threshold)
x = self.ifft_layer(x)

return x + t_in


class Sup3rAdder(tf.keras.layers.Layer):
"""Layer to add high-resolution data to a sup3r model in the middle of a
super resolution forward pass."""
Expand All @@ -609,7 +704,7 @@ def __init__(self, name=None):
Unique str identifier of the adder layer. Usually the name of the
hi-resolution feature used in the addition.
"""
super().__init__(name=name)
self.name = name

def call(self, x, hi_res_adder):
"""Adds hi-resolution data to the input tensor x in the middle of a
Expand Down Expand Up @@ -644,7 +739,7 @@ def __init__(self, name=None):
Unique str identifier for the concat layer. Usually the name of the
hi-resolution feature used in the concatenation.
"""
super().__init__(name=name)
self.name = name

def call(self, x, hi_res_feature):
"""Concatenates a hi-resolution feature to the input tensor x in the
Expand Down

0 comments on commit 452af90

Please sign in to comment.