Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] Add AEBiGRUNetwork #1583

Merged
merged 14 commits into from
Jul 11, 2024
2 changes: 2 additions & 0 deletions aeon/networks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
"AEFCNNetwork",
"AEResNetNetwork",
"LITENetwork",
"AEBiGRUNetwork",
]
from aeon.networks._ae_bgru import AEBiGRUNetwork
from aeon.networks._ae_fcn import AEFCNNetwork
from aeon.networks._ae_resnet import AEResNetNetwork
from aeon.networks._cnn import CNNNetwork
Expand Down
142 changes: 142 additions & 0 deletions aeon/networks/_ae_bgru.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
"""Implement Auto-Encoder based on Bidirectional GRUs."""

from aeon.networks.base import BaseDeepLearningNetwork


class AEBiGRUNetwork(BaseDeepLearningNetwork):
"""
A class to implement an Auto-Encoder based on Bidirectional GRUs.

Parameters
----------
latent_space_dim : int, default=128
aadya940 marked this conversation as resolved.
Show resolved Hide resolved
Dimension of the latent space.
n_layers : int, default=None
Number of BiGRU layers. If None, defaults will be used.
n_units : list
Number of units in each BiGRU layer. If None, defaults will be used.
activation : Union[list, str]
Activation function(s) to use in each layer.
Can be a single string or a list.
temporal_latent_space : bool, default = False
Flag to choose whether the latent space is an MTS or Euclidean space.
"""

_config = {
"python_dependencies": ["tensorflow"],
"python_version": "<3.12",
"structure": "encoder",
}

def __init__(
self,
latent_space_dim=128,
n_layers=None,
n_units=None,
activation="relu",
temporal_latent_space=False,
):
super().__init__()

self.latent_space_dim = latent_space_dim
self.activation = activation
self.n_layers = n_layers
self.n_units = n_units
self.temporal_latent_space = temporal_latent_space

def build_network(self, input_shape, **kwargs):
"""Construct a network and return its input and output layers.

Parameters
----------
input_shape : tuple of shape = (n_timepoints (m), n_channels (d))
The shape of the data fed into the input layer.

Returns
-------
encoder : a keras Model.
decoder : a keras Model.
"""
import tensorflow as tf

if self.n_layers is None:
if self.n_units is not None:
raise ValueError(
"""Cannot pass number of units without specifying
number of layers."""
)
elif self.n_units is None:
self._n_layers, self._n_units = 2, [50, self.latent_space_dim // 2]
elif self.n_layers is not None:
self._n_layers = self.n_layers
if self.n_units is None:
self._n_units = [50 for _ in range(self.n_layers)]
self._n_units[-1] = self.latent_space_dim // 2
elif self.n_units is not None:
if isinstance(self.n_units, list):
self._n_units = self.n_units
self._n_units[-1] = self.latent_space_dim // 2
assert len(self.n_units) == self.n_layers
elif isinstance(self.n_units, int):
self._n_units = [self.n_units for _ in range(self.n_layers)]
self._n_units[-1] = self.latent_space_dim // 2

if isinstance(self.activation, str):
self._activation = [self.activation for _ in range(self._n_layers)]
else:
self._activation = self.activation
assert isinstance(self.activation, list)
assert len(self.activation) == self._n_layers

encoder_inputs = tf.keras.layers.Input(shape=input_shape, name="encoder_input")
x = encoder_inputs
for i in range(self._n_layers):
return_sequences = i < self._n_layers - 1
if self.temporal_latent_space:
return_sequences = i < self._n_layers
x = tf.keras.layers.Bidirectional(
tf.keras.layers.GRU(
units=self._n_units[i],
activation=self._activation[i],
return_sequences=return_sequences,
),
name=f"encoder_bgru_{i+1}",
)(x)

latent_space = tf.keras.layers.Dense(
self.latent_space_dim, activation="linear", name="latent_space"
)(x)
encoder_model = tf.keras.models.Model(
inputs=encoder_inputs, outputs=latent_space, name="encoder"
)

if not self.temporal_latent_space:
decoder_inputs = tf.keras.layers.Input(
shape=(self.latent_space_dim,), name="decoder_input"
)
x = tf.keras.layers.RepeatVector(input_shape[0], name="repeat_vector")(
decoder_inputs
)
elif self.temporal_latent_space:
decoder_inputs = tf.keras.layers.Input(
shape=latent_space.shape[1:], name="decoder_input"
)
x = decoder_inputs

for i in range(self._n_layers - 1, -1, -1):
x = tf.keras.layers.Bidirectional(
tf.keras.layers.GRU(
units=self._n_units[i],
activation=self._activation[i],
return_sequences=True,
),
name=f"decoder_bgru_{i+1}",
)(x)
decoder_outputs = tf.keras.layers.TimeDistributed(
tf.keras.layers.Dense(input_shape[1]), name="decoder_output"
)(x)
decoder_model = tf.keras.models.Model(
inputs=decoder_inputs, outputs=decoder_outputs, name="decoder"
)

return encoder_model, decoder_model
54 changes: 54 additions & 0 deletions aeon/networks/tests/test_ae_bgru.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
"""Tests for the AEBiGRU Model."""

import random

import pytest

from aeon.networks import AEBiGRUNetwork
hadifawaz1999 marked this conversation as resolved.
Show resolved Hide resolved
from aeon.utils.validation._dependencies import _check_soft_dependencies


@pytest.mark.skipif(
not _check_soft_dependencies(["tensorflow"], severity="none"),
reason="Tensorflow soft dependency unavailable.",
)
@pytest.mark.parametrize(
"latent_space_dim,n_layers,temporal_latent_space",
[
(32, 1, True),
(128, 2, False),
(256, 3, True),
(64, 4, False),
],
)
def test_aebigrunetwork_init(latent_space_dim, n_layers, temporal_latent_space):
"""Test whether AEBiGRUNetwork initializes correctly for various parameters."""
aebigru = AEBiGRUNetwork(
latent_space_dim=latent_space_dim,
n_layers=n_layers,
temporal_latent_space=temporal_latent_space,
activation=random.choice(["relu", "tanh"]),
n_units=[random.choice([50, 25, 100]) for _ in range(n_layers)],
)
encoder, decoder = aebigru.build_network((1000, 5))
assert encoder is not None
assert decoder is not None


@pytest.mark.skipif(
not _check_soft_dependencies(["tensorflow"], severity="none"),
reason="Tensorflow soft dependency unavailable.",
)
@pytest.mark.parametrize("activation", ["relu", "tanh"])
def test_aebigrunetwork_activations(activation):
"""Test whether AEBiGRUNetwork initializes correctly with different activations."""
aebigru = AEBiGRUNetwork(
latent_space_dim=64,
n_layers=2,
temporal_latent_space=True,
activation=activation,
n_units=[50, 50],
)
encoder, decoder = aebigru.build_network((1000, 5))
assert encoder is not None
assert decoder is not None