diff --git a/deeptrack/layers.py b/deeptrack/layers.py index cf380afb3..d62de25df 100644 --- a/deeptrack/layers.py +++ b/deeptrack/layers.py @@ -1,3 +1,4 @@ # For backwards compatability from .models.layers import * +from .models.embeddings import * from .models.gnns.layers import * \ No newline at end of file diff --git a/deeptrack/models/embeddings.py b/deeptrack/models/embeddings.py new file mode 100644 index 000000000..8d7fbbdd3 --- /dev/null +++ b/deeptrack/models/embeddings.py @@ -0,0 +1,148 @@ +import tensorflow as tf + +from .utils import single_layer_call +from .layers import register + + +class ClassToken(tf.keras.layers.Layer): + """ClassToken Layer.""" + + def build(self, input_shape): + cls_init = tf.zeros_initializer() + self.hidden_size = input_shape[-1] + self.cls = tf.Variable( + name="cls", + initial_value=cls_init(shape=(1, 1, self.hidden_size), dtype="float32"), + trainable=True, + ) + + def call(self, inputs): + batch_size = tf.shape(inputs)[0] + cls_broadcasted = tf.cast( + tf.broadcast_to(self.cls, [batch_size, 1, self.hidden_size]), + dtype=inputs.dtype, + ) + return tf.concat([cls_broadcasted, inputs], 1) + + +@register("ClassToken") +def ClassTokenLayer(activation=None, normalization=None, norm_kwargs={}, **kwargs): + """ClassToken Layer that append a class token to the input. + + Can optionally perform normalization or some activation function. + + Accepts arguments of keras.layers.Layer. + + Parameters + ---------- + + activation : str or activation function or layer + Activation function of the layer. See keras docs for accepted strings. + normalization : str or normalization function or layer + Normalization function of the layer. See keras and tfa docs for accepted strings. + norm_kwargs : dict + Arguments for the normalization function. + **kwargs + Other keras.layers.Layer arguments + """ + + def Layer(filters, **kwargs_inner): + kwargs_inner.update(kwargs) + layer = ClassToken(**kwargs_inner) + return lambda x: single_layer_call( + x, layer, activation, normalization, norm_kwargs + ) + + return Layer + + +class LearnablePositionEmbs(tf.keras.layers.Layer): + """Adds or concatenates positional embeddings to the inputs. + Parameters + ---------- + initializer : str or tf.keras.initializers.Initializer + Initializer function for the embeddings. See tf.keras.initializers.Initializer for accepted functions. + concat : bool + Whether to concatenate the positional embeddings to the inputs. If False, + adds the positional embeddings to the inputs. 
+ kwargs: dict + Other arguments for the keras.layers.Layer + """ + + def __init__( + self, + initializer=tf.keras.initializers.RandomNormal(stddev=0.06), + concat=False, + **kwargs, + ): + super().__init__(**kwargs) + self.concat = concat + + assert callable(initializer) or isinstance( + initializer, tf.keras.initializers.Initializer + ), "initializer must be callable or a tf.keras.initializers.Initializer" + self.initializer = initializer + + def build(self, input_shape): + assert ( + len(input_shape) == 3 + ), f"Number of dimensions should be 3, got {len(input_shape)}" + self.pos_embedding = tf.Variable( + name="pos_embedding", + initial_value=self.initializer(shape=(1, *(input_shape[-2:]))), + dtype="float32", + trainable=True, + ) + + def call(self, inputs): + if self.concat: + return tf.concat( + [inputs, tf.cast(self.pos_embedding, dtype=inputs.dtype)], axis=-1 + ) + else: + return inputs + tf.cast(self.pos_embedding, dtype=inputs.dtype) + + +@register("LearnablePositionEmbs") +def LearnablePositionEmbsLayer( + initializer=tf.keras.initializers.RandomNormal(stddev=0.06), + concat=False, + activation=None, + normalization=None, + norm_kwargs={}, + **kwargs, +): + """Adds or concatenates positional embeddings to the inputs. + + Can optionally perform normalization or some activation function. + + Accepts arguments of keras.layers.Layer. + + Parameters + ---------- + + initializer : str or tf.keras.initializers.Initializer + Initializer function for the embeddings. See tf.keras.initializers.Initializer for accepted functions. + concat : bool + Whether to concatenate the positional embeddings to the inputs. If False, + adds the positional embeddings to the inputs. + activation : str or activation function or layer + Activation function of the layer. See keras docs for accepted strings. + normalization : str or normalization function or layer + Normalization function of the layer. See keras and tfa docs for accepted strings. + norm_kwargs : dict + Arguments for the normalization function.
+ **kwargs + Other arguments for the keras.layers.Layer + """ + + def Layer(filters, **kwargs_inner): + kwargs_inner.update(kwargs) + layer = LearnablePositionEmbs( + initializer=initializer, concat=concat, **kwargs_inner + ) + return lambda x: single_layer_call( + x, layer, activation, normalization, norm_kwargs + ) + + return Layer diff --git a/deeptrack/models/gans/cgan.py b/deeptrack/models/gans/cgan.py index 05e70eaed..6fa24f1b8 100644 --- a/deeptrack/models/gans/cgan.py +++ b/deeptrack/models/gans/cgan.py @@ -1,5 +1,5 @@ import tensorflow as tf -from .utils import as_KerasModel +from ..utils import as_KerasModel layers = tf.keras.layers diff --git a/deeptrack/models/gans/gan.py b/deeptrack/models/gans/gan.py index 08a76cafb..ec1b5789c 100644 --- a/deeptrack/models/gans/gan.py +++ b/deeptrack/models/gans/gan.py @@ -1,5 +1,5 @@ import tensorflow as tf -from .utils import as_KerasModel +from ..utils import as_KerasModel layers = tf.keras.layers diff --git a/deeptrack/models/gans/pcgan.py b/deeptrack/models/gans/pcgan.py index 78fdb9bf5..1bb848d65 100644 --- a/deeptrack/models/gans/pcgan.py +++ b/deeptrack/models/gans/pcgan.py @@ -1,5 +1,5 @@ import tensorflow as tf -from .utils import as_KerasModel +from ..utils import as_KerasModel layers = tf.keras.layers @@ -18,20 +18,20 @@ class PCGAN(tf.keras.Model): discriminator_loss: str or keras loss function The loss function of the discriminator network discriminator_optimizer: str or keras optimizer - The optimizer of the discriminator network + The optimizer of the discriminator network discriminator_metrics: list, optional List of metrics to be evaluated by the discriminator - model during training and testing + model during training and testing assemble_loss: list of str or keras loss functions List of loss functions to be evaluated on each output of the assemble model (stacked generator and discriminator), such as `assemble_loss = ["mse", "mse", "mae"]` for the prediction of the discriminator, the predicted - perceptual features, and the generated image, respectively + perceptual features, and the generated image, respectively assemble_optimizer: str or keras optimizer - The optimizer of the assemble network + The optimizer of the assemble network assemble_loss_weights: list or dict, optional - List or dictionary specifying scalar coefficients (floats) + List or dictionary specifying scalar coefficients (floats) to weight the loss contributions of the assemble model outputs perceptual_discriminator: str or keras model Name of the perceptual discriminator. Select the name of this network @@ -43,8 +43,8 @@ class PCGAN(tf.keras.Model): ImageNet weights, or provide the path to the weights file to be loaded. Only to be specified if `perceptual_discriminator` is a keras application model. - metrics: list, optional - List of metrics to be evaluated on the generated images during + metrics: list, optional + List of metrics to be evaluated on the generated images during training and testing """ @@ -77,9 +77,7 @@ def __init__( if isinstance(perceptual_discriminator, str): self.perceptual_discriminator = tf.keras.Sequential( [ - layers.Lambda( - lambda img: layers.Concatenate(axis=-1)([img] * 3) - ), + layers.Lambda(lambda img: layers.Concatenate(axis=-1)([img] * 3)), getattr(tf.keras.applications, perceptual_discriminator)( include_top=False, weights=perceptual_discriminator_weights, @@ -93,8 +91,8 @@ def __init__( else: raise AttributeError( - 'Invalid model format. 
perceptual_discriminator must be either a string ' - 'indicating the name of the pre-trained model, or a keras model.' + "Invalid model format. perceptual_discriminator must be either a string " + "indicating the name of the pre-trained model, or a keras model." ) self.perceptual_discriminator.trainable = False @@ -164,9 +162,7 @@ def train_step(self, data): with tf.GradientTape() as tape: assemble_output = self.assemble(batch_x) - generated_image_copies = [assemble_output[2]] * ( - self.num_losses - 1 - ) + generated_image_copies = [assemble_output[2]] * (self.num_losses - 1) batch_y_copies = [batch_y] * (self.num_losses - 1) diff --git a/deeptrack/models/gnns/__init__.py b/deeptrack/models/gnns/__init__.py new file mode 100644 index 000000000..cf4f59d6c --- /dev/null +++ b/deeptrack/models/gnns/__init__.py @@ -0,0 +1 @@ +from .models import * \ No newline at end of file diff --git a/deeptrack/models/gnns/layers.py b/deeptrack/models/gnns/layers.py index 38dad9137..6b7a4ceed 100644 --- a/deeptrack/models/gnns/layers.py +++ b/deeptrack/models/gnns/layers.py @@ -1,9 +1,8 @@ import tensorflow as tf from tensorflow.keras import layers -from ..layers import * - -GraphDenseBlock = DenseBlock(activation=GELU, normalization="LayerNormalization") +from ..layers import MultiHeadGatedSelfAttention, MultiHeadSelfAttention, register +from ..utils import as_activation, as_normalization, single_layer_call, GELU class FGNN(tf.keras.layers.Layer): @@ -13,12 +12,18 @@ class FGNN(tf.keras.layers.Layer): ---------- filters : int Number of filters. - message_layer : str or callable - Message layer. - update_layer : str or callable - Update layer. + activation : str or activation function or layer + Activation function of the layer. See keras docs for accepted strings. + normalization : str or normalization function or layer + Normalization function of the layer. See keras and tfa docs for accepted strings. random_edge_dropout : float, optional Random edge dropout. + use_gates : bool, optional + Whether to use gated self-attention layers as update layer. Defaults to True. + att_layer_kwargs : dict, optional + Keyword arguments for the self-attention layer. + norm_kwargs : dict + Arguments for the normalization function. kwargs : dict Additional arguments. 
""" @@ -26,17 +31,37 @@ class FGNN(tf.keras.layers.Layer): def __init__( self, filters, - message_layer=GraphDenseBlock, - update_layer="MultiHeadGatedSelfAttention", + activation=GELU, + normalization="LayerNormalization", random_edge_dropout=False, + use_gates=True, + att_layer_kwargs={}, + norm_kwargs={}, **kwargs, ): super().__init__(**kwargs) self.filters = filters self.random_edge_dropout = random_edge_dropout - self.message_layer = as_block(message_layer)(filters) - self.update_layer = as_block(update_layer)(None) + # self.message_layer = as_block(message_layer)(filters) + self.message_layer = tf.keras.Sequential( + [ + layers.Dense(self.filters), + as_activation(activation), + as_normalization(normalization)(**norm_kwargs), + ] + ) + + _multi_head_att_layer = ( + MultiHeadGatedSelfAttention if use_gates else MultiHeadSelfAttention + ) + self.update_layer = tf.keras.Sequential( + [ + _multi_head_att_layer(**att_layer_kwargs), + as_activation(activation), + as_normalization(normalization)(**norm_kwargs), + ] + ) def build(self, input_shape): self.sigma = tf.Variable( @@ -132,31 +157,31 @@ def aggregate(_, x): return (updated_nodes, weighted_messages, distance, edges) -@register("FGnn") +@register("FGNN") def FGNNlayer( - message_layer=GraphDenseBlock, - update_layer="MultiHeadGatedSelfAttention", + activation=GELU, + normalization="LayerNormalization", random_edge_dropout=False, - activation=None, - normalization=None, + use_gates=True, + att_layer_kwargs={}, norm_kwargs={}, **kwargs, ): """Fingerprinting Graph Layer. Parameters ---------- - number_of_heads : int - Number of attention heads. - message_layer : str or callable - Message layer. - update_layer : str or callable - Update layer. - random_edge_dropout : float, optional - Random edge dropout. + filters : int + Number of filters. activation : str or activation function or layer Activation function of the layer. See keras docs for accepted strings. normalization : str or normalization function or layer Normalization function of the layer. See keras and tfa docs for accepted strings. + random_edge_dropout : float, optional + Random edge dropout. + use_gates : bool, optional + Whether to use gated self-attention layers as update layer. Defaults to True. + att_layer_kwargs : dict, optional + Keyword arguments for the self-attention layer. norm_kwargs : dict Arguments for the normalization function. kwargs : dict @@ -167,14 +192,15 @@ def Layer(filters, **kwargs_inner): kwargs_inner.update(kwargs) layer = FGNN( filters, - message_layer, - update_layer, + activation, + normalization, random_edge_dropout, + use_gates, + att_layer_kwargs, + norm_kwargs, **kwargs_inner, ) - return lambda x: single_layer_call( - x, layer, activation, normalization, norm_kwargs - ) + return lambda x: single_layer_call(x, layer, None, None, {}) return Layer @@ -186,24 +212,30 @@ class ClassTokenFGNN(FGNN): ---------- filters : int Number of filters. - message_layer : str or callable - Message layer. - update_layer : str or callable - Update layer. + activation : str or activation function or layer + Activation function of the layer. See keras docs for accepted strings. + normalization : str or normalization function or layer + Normalization function of the layer. See keras and tfa docs for accepted strings. random_edge_dropout : float, optional Random edge dropout. + use_gates : bool, optional + Whether to use gated self-attention layers as update layer. Defaults to True. 
+ att_layer_kwargs : dict, optional + Keyword arguments for the self-attention layer. + norm_kwargs : dict + Arguments for the normalization function. kwargs : dict Additional arguments. """ def build(self, input_shape): + super().build(input_shape) self.combine_layer = tf.keras.Sequential( [ tf.keras.layers.Lambda(lambda x: tf.concat(x, axis=-1)), layers.Dense(self.filters), ] ) - super().build(input_shape) def call(self, inputs): nodes, edge_features, distance, edges = inputs @@ -287,12 +319,13 @@ def aggregate(_, x): return (updated_nodes, weighted_messages, distance, edges) +@register("CTFGNN") def ClassTokenFGNNlayer( - message_layer=GraphDenseBlock, - update_layer="MultiHeadGatedSelfAttention", + activation=GELU, + normalization="LayerNormalization", random_edge_dropout=False, - activation=None, - normalization=None, + use_gates=True, + att_layer_kwargs={}, norm_kwargs={}, **kwargs, ): @@ -321,13 +354,14 @@ def Layer(filters, **kwargs_inner): kwargs_inner.update(kwargs) layer = ClassTokenFGNN( filters, - message_layer, - update_layer, + activation, + normalization, random_edge_dropout, + use_gates, + att_layer_kwargs, + norm_kwargs, **kwargs_inner, ) - return lambda x: single_layer_call( - x, layer, activation, normalization, norm_kwargs - ) + return lambda x: single_layer_call(x, layer, None, None, {}) return Layer \ No newline at end of file diff --git a/deeptrack/models/gnns/models.py b/deeptrack/models/gnns/models.py new file mode 100644 index 000000000..6e3f03e50 --- /dev/null +++ b/deeptrack/models/gnns/models.py @@ -0,0 +1,327 @@ +import tensorflow as tf +from tensorflow.keras import layers + +from ..utils import KerasModel, GELU +from ..layers import as_block, DenseBlock + + +class MAGIK(KerasModel): + """ + Message passing graph neural network. + Parameters: + ----------- + dense_layer_dimensions: list of ints + List of the number of units in each dense layer of the encoder and decoder. The + number of layers is inferred from the length of this list. + base_layer_dimensions: list of ints + List of the latent dimensions of the graph blocks. The number of layers is + inferred from the length of this list. + number_of_node_outputs: int + Number of output node features. + number_of_edge_outputs: int + Number of output edge features. + node_output_activation: str + Activation function for the output node layer. + edge_output_activation: str + Activation function for the output edge layer. + dense_block: str, keras.layers.Layer, or callable + The dense block to use for the encoder and decoder. + graph_block: str, keras.layers.Layer, or callable + The graph block to use for the graph blocks. + output_type: str + Type of output. Either "nodes", "edges", or "graph". + If 'key' is not a supported output type, then the + model output will be the concatenation of the node + and edge predictions. + kwargs: dict + Keyword arguments for the dense block. + Returns: + -------- + tf.keras.Model + Keras model for the graph neural network. 
+ """ + + def __init__( + self, + dense_layer_dimensions=(32, 64, 96), + base_layer_dimensions=(96, 96), + number_of_node_features=3, + number_of_edge_features=1, + number_of_node_outputs=1, + number_of_edge_outputs=1, + node_output_activation=None, + edge_output_activation=None, + dense_block=DenseBlock(activation=GELU, normalization="LayerNormalization"), + graph_block="FGNN", + output_type="graph", + **kwargs + ): + + dense_block = as_block(dense_block) + graph_block = as_block(graph_block) + + node_features, edge_features, edges, edge_weights = ( + tf.keras.Input(shape=(None, number_of_node_features)), + tf.keras.Input(shape=(None, number_of_edge_features)), + tf.keras.Input(shape=(None, 2), dtype=tf.int32), + tf.keras.Input(shape=(None, 2)), + ) + + node_layer = node_features + edge_layer = edge_features + + # Encoder for node and edge features + for dense_layer_number, dense_layer_dimension in zip( + range(len(dense_layer_dimensions)), dense_layer_dimensions + ): + node_layer = dense_block( + dense_layer_dimension, + name="node_ide" + str(dense_layer_number + 1), + **kwargs + )(node_layer) + + edge_layer = dense_block( + dense_layer_dimension, + name="edge_ide" + str(dense_layer_number + 1), + **kwargs + )(edge_layer) + + # Extract distance matrix + distance = edge_features[..., 0] + + # Bottleneck path, graph blocks + layer = (node_layer, edge_layer, distance, edges) + for base_layer_number, base_layer_dimension in zip( + range(len(base_layer_dimensions)), base_layer_dimensions + ): + layer = graph_block( + base_layer_dimension, + name="graph_block_" + str(base_layer_number), + )(layer) + + # Decoder for node and edge features + node_layer, edge_layer, *_ = layer + for dense_layer_number, dense_layer_dimension in zip( + range(len(dense_layer_dimensions)), + reversed(dense_layer_dimensions), + ): + node_layer = dense_block( + dense_layer_dimension, + name="node_idd" + str(dense_layer_number + 1), + **kwargs + )(node_layer) + + edge_layer = dense_block( + dense_layer_dimension, + name="edge_idd" + str(dense_layer_number + 1), + **kwargs + )(edge_layer) + + # Output layers + node_output = layers.Dense( + number_of_node_outputs, + activation=node_output_activation, + name="node_prediction", + )(node_layer) + + edge_output = layers.Dense( + number_of_edge_outputs, + activation=edge_output_activation, + name="edge_prediction", + )(edge_layer) + + output_dict = { + "nodes": node_output, + "edges": edge_output, + "graph": [node_output, edge_output], + } + try: + outputs = output_dict[output_type] + except KeyError: + outputs = output_dict["graph"] + + model = tf.keras.models.Model( + [node_features, edge_features, edges, edge_weights], + outputs, + ) + + super().__init__(model, **kwargs) + + +class CTMAGIK(KerasModel): + """ + Message passing graph neural network. + Parameters: + ----------- + dense_layer_dimensions: list of ints + List of the number of units in each dense layer of the encoder and decoder. The + number of layers is inferred from the length of this list. + base_layer_dimensions: list of ints + List of the latent dimensions of the graph blocks. The number of layers is + inferred from the length of this list. + number_of_node_outputs: int + Number of output node features. + number_of_edge_outputs: int + Number of output edge features. + number_of_global_outputs: int + Number of output global features. + node_output_activation: str or activation function or layer + Activation function for the output node layer. See keras docs for accepted strings. 
+ edge_output_activation: str or activation function or layer + Activation function for the output edge layer. See keras docs for accepted strings. + cls_layer_dimension: int + Number of units in the decoder layer for global features. + global_output_activation: str or activation function or layer + Activation function for the output global layer. See keras docs for accepted strings. + dense_block: str, keras.layers.Layer, or callable + The dense block to use for the encoder and decoder. + graph_block: str, keras.layers.Layer, or callable + The graph block to use for the graph blocks. + classtokens_block: str, keras.layers.Layer, or callable + The embedding block to use for the class tokens. + output_type: str + Type of output. Either "nodes", "edges", "global" or + "graph". If 'key' is not a supported output type, then + the model output will be the concatenation of the node, + edge, and global predictions. + kwargs: dict + Keyword arguments for the dense block. + Returns: + -------- + tf.keras.Model + Keras model for the graph neural network. + """ + + def __init__( + self, + dense_layer_dimensions=(32, 64, 96), + base_layer_dimensions=(96, 96), + number_of_node_features=3, + number_of_edge_features=1, + number_of_node_outputs=1, + number_of_edge_outputs=1, + number_of_global_outputs=1, + node_output_activation=None, + edge_output_activation=None, + cls_layer_dimension=64, + global_output_activation=None, + dense_block=DenseBlock(activation=GELU, normalization="LayerNormalization"), + graph_block="CTFGNN", + classtoken_block="ClassToken", + output_type="graph", + **kwargs + ): + + dense_block = as_block(dense_block) + graph_block = as_block(graph_block) + classtoken_block = as_block(classtoken_block) + + node_features, edge_features, edges, edge_weights = ( + tf.keras.Input(shape=(None, number_of_node_features)), + tf.keras.Input(shape=(None, number_of_edge_features)), + tf.keras.Input(shape=(None, 2), dtype=tf.int32), + tf.keras.Input(shape=(None, 2)), + ) + + node_layer = node_features + edge_layer = edge_features + + # Encoder for node and edge features + for dense_layer_number, dense_layer_dimension in zip( + range(len(dense_layer_dimensions)), dense_layer_dimensions + ): + node_layer = dense_block( + dense_layer_dimension, + name="node_ide" + str(dense_layer_number + 1), + **kwargs + )(node_layer) + + edge_layer = dense_block( + dense_layer_dimension, + name="edge_ide" + str(dense_layer_number + 1), + **kwargs + )(edge_layer) + + # Extract distance matrix + distance = edge_features[..., 0] + + # Bottleneck path, graph blocks + layer = ( + classtoken_block(base_layer_dimensions, name="ClassTokenLayer")(node_layer), + edge_layer, + distance, + edges, + ) + for base_layer_number, base_layer_dimension in zip( + range(len(base_layer_dimensions)), base_layer_dimensions + ): + layer = graph_block( + base_layer_dimension, + name="graph_block_" + str(base_layer_number), + )(layer) + + # Decoder for node, edge, and global features + node_layer, edge_layer, *_ = layer + # Split node and global features + cls_layer, node_layer = ( + tf.keras.layers.Lambda(lambda x: x[:, 0], name="RetrieveClassToken")( + node_layer + ), + node_layer[:1:], + ) + for dense_layer_number, dense_layer_dimension in zip( + range(len(dense_layer_dimensions)), + reversed(dense_layer_dimensions), + ): + node_layer = dense_block( + dense_layer_dimension, + name="node_idd" + str(dense_layer_number + 1), + **kwargs + )(node_layer) + + edge_layer = dense_block( + dense_layer_dimension, + name="edge_idd" + 
str(dense_layer_number + 1), + **kwargs + )(edge_layer) + + cls_layer = dense_block(cls_layer_dimension, name="cls_mlp", **kwargs)( + cls_layer + ) + + # Output layers + node_output = layers.Dense( + number_of_node_outputs, + activation=node_output_activation, + name="node_prediction", + )(node_layer) + + edge_output = layers.Dense( + number_of_edge_outputs, + activation=edge_output_activation, + name="edge_prediction", + )(edge_layer) + + global_output = layers.Dense( + number_of_global_outputs, + activation=global_output_activation, + name="global_prediction", + )(cls_layer) + + output_dict = { + "nodes": node_output, + "edges": edge_output, + "global": global_output, + "graph": [node_output, edge_output, global_output], + } + try: + outputs = output_dict[output_type] + except KeyError: + outputs = output_dict["graph"] + + model = tf.keras.models.Model( + [node_features, edge_features, edges, edge_weights], + outputs, + ) + + super().__init__(model, **kwargs) diff --git a/deeptrack/models/layers.py b/deeptrack/models/layers.py index d62ce47c4..b2995998c 100644 --- a/deeptrack/models/layers.py +++ b/deeptrack/models/layers.py @@ -2,27 +2,16 @@ """ +from typing_extensions import Self from warnings import WarningMessage from tensorflow.keras import layers import tensorflow as tf -try: - import tensorflow_addons as tfa +from .utils import single_layer_call, as_activation, as_normalization, GELU - InstanceNormalization = tfa.layers.InstanceNormalization - GELU = layers.Lambda(lambda x: tfa.activations.gelu(x, approximate=False)) -except Exception: - import warnings +from functools import reduce - InstanceNormalization, GELU = (layers.Layer(),) * 2 - warnings.warn( - "DeepTrack not installed with tensorflow addons. Instance normalization and GELU activation will not work. Consider upgrading to tensorflow >= 2.0.", - ImportWarning, - ) - -import pkg_resources - -installed_pkg = [pkg.key for pkg in pkg_resources.working_set] +import warnings BLOCKS = {} @@ -60,52 +49,6 @@ def as_block(x): return x -def _as_activation(x): - if x is None: - return layers.Layer() - elif isinstance(x, str): - return layers.Activation(x) - elif isinstance(x, layers.Layer): - return x - else: - return layers.Layer(x) - - -def _get_norm_by_name(x): - if hasattr(layers, x): - return getattr(layers, x) - elif "tensorflow-addons" in installed_pkg and hasattr(tfa.layers, x): - return getattr(tfa.layers, x) - else: - raise ValueError(f"Unknown normalization {x}.") - - -def _as_normalization(x): - if x is None: - return layers.Layer() - elif isinstance(x, str): - return _get_norm_by_name(x) - elif isinstance(x, layers.Layer) or callable(x): - return x - else: - return layers.Layer(x) - - -def single_layer_call(x, layer, activation, normalization, norm_kwargs): - assert isinstance(norm_kwargs, dict), "norm_kwargs must be a dict. Got {0}".format( - type(norm_kwargs) - ) - y = layer(x) - - if activation: - y = _as_activation(activation)(y) - - if normalization: - y = _as_normalization(normalization)(**norm_kwargs)(y) - - return y - - @register("convolutional", "conv") def ConvolutionalBlock( kernel_size=3, @@ -380,7 +323,7 @@ def call(x): y = single_layer_call(y, conv2, None, normalization, norm_kwargs) y = layers.Add()([identity(x), y]) if activation: - y = _as_activation(activation)(y) + y = as_activation(activation)(y) return y return call @@ -392,7 +335,7 @@ def call(x): def Identity(activation=None, normalization=False, norm_kwargs={}, **kwargs): """Identity layer that returns the input tensor. 
- Can optionally perform instance normalization or some activation function. + Can optionally perform normalization or some activation function. Accepts arguments of keras.layers.Layer. @@ -424,14 +367,25 @@ class MultiHeadSelfAttention(layers.Layer): ---------- number_of_heads : int Number of attention heads. + use_bias : bool + Whether to use bias in attention layer. + return_attention_weights : bool + Whether to return the attention weights for visualization. kwargs Other arguments for the keras.layers.Layer """ - def __init__(self, number_of_heads, use_bias=True, **kwargs): + def __init__( + self, + number_of_heads=12, + use_bias=True, + return_attention_weights=False, + **kwargs, + ): super().__init__(**kwargs) self.number_of_heads = number_of_heads self.use_bias = use_bias + self.return_attention_weights = return_attention_weights def build(self, input_shape): try: @@ -449,6 +403,7 @@ def build(self, input_shape): self.query_dense = layers.Dense(filters, use_bias=self.use_bias) self.key_dense = layers.Dense(filters, use_bias=self.use_bias) self.value_dense = layers.Dense(filters, use_bias=self.use_bias) + self.combine_dense = layers.Dense(filters, use_bias=self.use_bias) def SingleAttention(self, query, key, value, gate=None, **kwargs): @@ -472,7 +427,7 @@ def SingleAttention(self, query, key, value, gate=None, **kwargs): weights = tf.nn.softmax(scaled_score, axis=-1) output = tf.matmul(weights, value) - if gate: + if gate is not None: output = tf.math.multiply(output, gate) return output, weights @@ -526,20 +481,42 @@ def call(self, x, **kwargs): x : tuple of tf.Tensors Input tensors. """ - (attention, _), batch_size = self.compute_attention(x, **kwargs) + (attention, weights), batch_size = self.compute_attention(x, **kwargs) attention = tf.transpose(attention, perm=[0, 2, 1, 3]) concat_attention = tf.reshape(attention, (batch_size, -1, self.filters)) output = self.combine_dense(concat_attention) - return output + if self.return_attention_weights: + return output, weights + else: + return output class MultiHeadGatedSelfAttention(MultiHeadSelfAttention): def build(self, input_shape): - super().build(input_shape) - self.gate_dense = layers.Dense(self.filters, activation="sigmoid") + """ + Build the layer. + """ + try: + filters = input_shape[1][-1] + except TypeError: + filters = input_shape[-1] + + if filters % self.number_of_heads != 0: + raise ValueError( + f"embedding dimension = {filters} should be divisible by number of heads = {self.number_of_heads}" + ) + self.filters = filters + self.projection_dim = filters // self.number_of_heads + + self.query_dense = layers.Dense(filters) + self.key_dense = layers.Dense(filters) + self.value_dense = layers.Dense(filters) + self.gate_dense = layers.Dense(filters, activation="sigmoid") + + self.combine_dense = layers.Dense(filters) - def compute_gated_attention(self, x, **kwargs): + def compute_attention(self, x, **kwargs): """ Compute attention. Parameters @@ -575,18 +552,26 @@ def compute_gated_attention(self, x, **kwargs): def MultiHeadSelfAttentionLayer( number_of_heads=12, use_bias=True, - activation=GELU, + return_attention_weights=False, + activation="relu", normalization="LayerNormalization", norm_kwargs={}, **kwargs, ): """Multi-head self-attention layer. + + Can optionally perform normalization or some activation function. + + Accepts arguments of keras.layers.Layer. + Parameters ---------- number_of_heads : int Number of attention heads. use_bias : bool Whether to use bias in the dense layers. 
+ return_attention_weights : bool + Whether to return attention weights for visualization. activation : str or activation function or layer Activation function of the layer. See keras docs for accepted strings. normalization : str or normalization function or layer @@ -599,7 +584,9 @@ def MultiHeadSelfAttentionLayer( def Layer(filters, **kwargs_inner): kwargs_inner.update(kwargs) - layer = MultiHeadSelfAttention(number_of_heads, use_bias, **kwargs_inner) + layer = MultiHeadSelfAttention( + number_of_heads, use_bias, return_attention_weights, **kwargs_inner + ) return lambda x: single_layer_call( x, layer, activation, normalization, norm_kwargs ) @@ -611,18 +598,26 @@ def Layer(filters, **kwargs_inner): def MultiHeadGatedSelfAttentionLayer( number_of_heads=12, use_bias=True, - activation=GELU, + return_attention_weights=False, + activation="relu", normalization="LayerNormalization", norm_kwargs={}, **kwargs, ): """Multi-head gated self-attention layer. + + Can optionally perform normalization or some activation function. + + Accepts arguments of keras.layers.Layer. + Parameters ---------- number_of_heads : int Number of attention heads. use_bias : bool Whether to use bias in the dense layers. + return_attention_weights : bool + Whether to return attention weights for visualization. activation : str or activation function or layer Activation function of the layer. See keras docs for accepted strings. normalization : str or normalization function or layer @@ -635,7 +630,9 @@ def MultiHeadGatedSelfAttentionLayer( def Layer(filters, **kwargs_inner): kwargs_inner.update(kwargs) - layer = MultiHeadGatedSelfAttention(number_of_heads, use_bias, **kwargs_inner) + layer = MultiHeadGatedSelfAttention( + number_of_heads, use_bias, return_attention_weights, **kwargs_inner + ) return lambda x: single_layer_call( x, layer, activation, normalization, norm_kwargs ) @@ -643,52 +640,137 @@ def Layer(filters, **kwargs_inner): return Layer -class ClassToken(tf.keras.layers.Layer): - """ClassToken Layer.""" +class TransformerEncoder(tf.keras.layers.Layer): + """Transformer Encoder. + Parameters + ---------- + fwd_mlp_dim : int + Dimension of the forward MLP. + number_of_heads : int + Number of attention heads. + dropout : float + Dropout rate. + activation : str or activation function or layer + Activation function of the layer. See keras docs for accepted strings. + normalization : str or normalization function or layer + Normalization function of the layer. See keras and tfa docs for accepted strings. + use_gates : bool, optional + Whether to use gated self-attention layers as update layer. Defaults to False. + use_bias: bool, optional + Whether to use bias in the dense layers of the attention layers. Defaults to False. + norm_kwargs : dict + Arguments for the normalization function. + kwargs : dict + Additional arguments. 
+ """ - def build(self, input_shape): - cls_init = tf.zeros_initializer() - self.hidden_size = input_shape[-1] - self.cls = tf.Variable( - name="cls", - initial_value=cls_init(shape=(1, 1, self.hidden_size), dtype="float32"), - trainable=True, - ) + def __init__( + self, + fwd_mlp_dim, + number_of_heads=12, + dropout=0.0, + activation=GELU, + normalization="LayerNormalization", + use_gates=False, + use_bias=False, + norm_kwargs={}, + **kwargs, + ): + super().__init__(**kwargs) + self.number_of_heads = number_of_heads + self.use_bias = use_bias + self.use_gates = use_gates + + self.fwd_mlp_dim = fwd_mlp_dim + self.dropout = dropout + + self.activation = activation - def call(self, inputs): - batch_size = tf.shape(inputs)[0] - cls_broadcasted = tf.cast( - tf.broadcast_to(self.cls, [batch_size, 1, self.hidden_size]), - dtype=inputs.dtype, + self.normalization = normalization + + self.MultiHeadAttLayer = ( + MultiHeadGatedSelfAttention if self.use_gates else MultiHeadSelfAttention + )( + number_of_heads=self.number_of_heads, + use_bias=self.use_bias, + return_attention_weights=True, + name="MultiHeadAttLayer", + ) + self.norm_0, self.norm_1 = ( + as_normalization(normalization)(**norm_kwargs), + as_normalization(normalization)(**norm_kwargs), ) - return tf.concat([cls_broadcasted, inputs], 1) + self.dropout_layer = tf.keras.layers.Dropout(self.dropout) + def build(self, input_shape): + self.feed_forward_layer = tf.keras.Sequential( + [ + layers.Dense( + self.fwd_mlp_dim, + name=f"{self.name}/Dense_0", + ), + as_activation(self.activation), + layers.Dropout(self.dropout), + layers.Dense(input_shape[-1], name=f"{self.name}/Dense_1"), + layers.Dropout(self.dropout), + ], + name="feed_forward", + ) -def ClassTokenLayer(activation=None, normalization=None, norm_kwargs={}, **kwargs): - """ClassToken Layer that append a class token to the input. + def call(self, inputs, training): + x, weights = self.MultiHeadAttLayer(inputs) + x = self.dropout_layer(x, training=training) + x = self.norm_0(inputs + x) - Can optionally perform instance normalization or some activation function. + y = self.feed_forward_layer(x) + return self.norm_1(x + y), weights - Accepts arguments of keras.layers.Layer. +@register("TransformerEncoder") +def TransformerEncoderLayer( + number_of_heads=12, + dropout=0.0, + activation=GELU, + normalization="LayerNormalization", + use_gates=False, + use_bias=False, + norm_kwargs={}, + **kwargs, +): + """Transformer Encoder Layer. Parameters ---------- - + number_of_heads : int + Number of attention heads. + dropout : float + Dropout rate. activation : str or activation function or layer Activation function of the layer. See keras docs for accepted strings. normalization : str or normalization function or layer Normalization function of the layer. See keras and tfa docs for accepted strings. + use_gates : bool, optional + Whether to use gated self-attention layers as update layer. Defaults to False. + use_bias: bool, optional + Whether to use bias in the dense layers of the attention layers. Defaults to True. norm_kwargs : dict Arguments for the normalization function. - **kwargs - Other keras.layers.Layer arguments + kwargs : dict + Additional arguments. 
""" def Layer(filters, **kwargs_inner): kwargs_inner.update(kwargs) - layer = ClassToken(**kwargs_inner) - return lambda x: single_layer_call( - x, layer, activation, normalization, norm_kwargs + layer = TransformerEncoder( + filters, + number_of_heads, + dropout, + activation, + normalization, + use_gates, + use_bias, + norm_kwargs, + **kwargs, ) + return lambda x: single_layer_call(x, layer, None, None, {}) return Layer diff --git a/deeptrack/models/utils.py b/deeptrack/models/utils.py index a010db9d9..488e58e9d 100644 --- a/deeptrack/models/utils.py +++ b/deeptrack/models/utils.py @@ -1,4 +1,4 @@ -from functools import wraps +from functools import wraps, reduce import numpy as np from tensorflow.keras import layers, models @@ -6,7 +6,37 @@ from .. import features from ..generators import ContinuousGenerator -__all__ = ["compile", "load_model", "Model", "KerasModel", "LoadModel"] + +try: + import tensorflow_addons as tfa + + InstanceNormalization = tfa.layers.InstanceNormalization + GELU = layers.Lambda(lambda x: tfa.activations.gelu(x, approximate=False)) +except Exception: + import warnings + + InstanceNormalization, GELU = (layers.Layer(),) * 2 + warnings.warn( + "DeepTrack not installed with tensorflow addons. Instance normalization and GELU activation will not work. Consider upgrading to tensorflow >= 2.0.", + ImportWarning, + ) + +import pkg_resources + +installed_pkg = [pkg.key for pkg in pkg_resources.working_set] + +__all__ = [ + "compile", + "load_model", + "Model", + "KerasModel", + "LoadModel", + "single_layer_call", + "as_activation", + "as_normalization", + "GELU", + "InstanceNormalization", +] def compile(model: models.Model, *, loss="mae", optimizer="adam", metrics=[], **kwargs): @@ -50,6 +80,55 @@ def LoadModel(path, compile_from_file=False, custom_objects={}, **kwargs): load_model = LoadModel +def as_activation(x): + if x is None: + return layers.Layer() + elif isinstance(x, str): + return layers.Activation(x) + elif isinstance(x, layers.Layer): + return x + else: + return layers.Layer(x) + + +def _get_norm_by_name(x): + if hasattr(layers, x): + return getattr(layers, x) + elif "tensorflow-addons" in installed_pkg and hasattr(tfa.layers, x): + return getattr(tfa.layers, x) + else: + raise ValueError(f"Unknown normalization {x}.") + + +def as_normalization(x): + if x is None: + return layers.Layer() + elif isinstance(x, str): + return _get_norm_by_name(x) + elif isinstance(x, layers.Layer) or callable(x): + return x + else: + return layers.Layer(x) + + +def single_layer_call( + x, layer, activation, normalization, norm_kwargs, activation_first=True +): + assert isinstance(norm_kwargs, dict), "norm_kwargs must be a dict. 
Got {0}".format( + type(norm_kwargs) + ) + + n = ( + lambda x: as_normalization(normalization)(**norm_kwargs)(x) + if normalization + else x + ) + a = lambda x: as_activation(activation)(x) if activation else x + fs = [layer, a, n] if activation_first else [layer, n, a] + + return reduce(lambda x, f: f(x), fs, x) + + def with_citation(citation): def wrapper(func): @wraps(func) @@ -117,7 +196,7 @@ def __init__( metrics=[], compile=True, add_batch_dimension_on_resolve=True, - **kwargs + **kwargs, ): if compile: @@ -127,7 +206,7 @@ def __init__( model, add_batch_dimension_on_resolve=add_batch_dimension_on_resolve, metrics=metrics, - **kwargs + **kwargs, ) @wraps(models.Model.fit) @@ -142,7 +221,7 @@ def fit(self, x, *args, batch_size=32, generator_kwargs={}, **kwargs): "max_data_size": batch_size * 50, }, **generator_kwargs, - } + }, ) with generator: h = self.model.fit(generator, *args, batch_size=batch_size, **kwargs) diff --git a/deeptrack/test/test_layers.py b/deeptrack/test/test_layers.py index 95f66757d..39f09170a 100644 --- a/deeptrack/test/test_layers.py +++ b/deeptrack/test/test_layers.py @@ -126,9 +126,9 @@ def test_Multi_Head_Gated_Attention_filters(self): self.assertEqual(model.layers[1].filters, 96) def test_FGNN_layer(self): - layer = layers.FGNNlayer() + block = layers.FGNNlayer() model = makeMinimalModel( - layer(96), + block(96), input_layer=( k_layers.Input(shape=(None, 96)), k_layers.Input(shape=(None, 10)), @@ -139,9 +139,9 @@ def test_FGNN_layer(self): self.assertTrue(model.layers[-1], layers.FGNN) def test_Class_Token_FGNN_layer(self): - layer = layers.ClassTokenFGNNlayer() + block = layers.ClassTokenFGNNlayer() model = makeMinimalModel( - layer(96), + block(96), input_layer=( k_layers.Input(shape=(None, 96)), k_layers.Input(shape=(None, 10)), @@ -151,10 +151,10 @@ def test_Class_Token_FGNN_layer(self): ) self.assertTrue(model.layers[-1], layers.ClassTokenFGNN) - def test_Class_Token_FGNN_message_layer(self): - layer = layers.ClassTokenFGNNlayer() + def test_Class_Token_FGNN_update_layer(self): + block = layers.ClassTokenFGNNlayer(att_layer_kwargs={"number_of_heads": 6}) model = makeMinimalModel( - layer(96), + block(96), input_layer=( k_layers.Input(shape=(None, 96)), k_layers.Input(shape=(None, 10)), @@ -162,14 +162,15 @@ def test_Class_Token_FGNN_message_layer(self): k_layers.Input(shape=(None, 2), dtype=tf.int32), ), ) - self.assertTrue(model.layers[-1].message_layer, layers.DenseBlock) + self.assertEqual(model.layers[-1].update_layer.layers[0].number_of_heads, 6) - def test_Class_Token_FGNN_update_layer(self): - layer = layers.ClassTokenFGNNlayer( - update_layer=layers.MultiHeadSelfAttentionLayer() + def test_Class_Token_FGNN_normalization(self): + # By setting center=False, scale=False, the number of trainable parameters should be 0 + block = layers.ClassTokenFGNNlayer( + norm_kwargs={"center": False, "scale": False, "axis": -1} ) model = makeMinimalModel( - layer(96), + block(96), input_layer=( k_layers.Input(shape=(None, 96)), k_layers.Input(shape=(None, 10)), @@ -177,7 +178,21 @@ def test_Class_Token_FGNN_update_layer(self): k_layers.Input(shape=(None, 2), dtype=tf.int32), ), ) - self.assertTrue(model.layers[-1].message_layer, layers.MultiHeadSelfAttention) + self.assertEqual(model.layers[-1].update_layer.layers[-1].count_params(), 0) + + def test_Transformer_Encoder(self): + block = layers.TransformerEncoderLayer() + model = makeMinimalModel(block(300), shape=(50, 300)) + self.assertTrue(model.layers[-1], layers.TransformerEncoder) + + def 
test_Transformer_Encoder_parameters(self): + block = layers.TransformerEncoderLayer(number_of_heads=6) + model = makeMinimalModel(block(300), shape=(50, 300)) + self.assertEqual(model.layers[-1].number_of_heads, 6) + + def test_Transformer_Encoder_bias(self): + block = layers.TransformerEncoderLayer(use_bias=True) + model = makeMinimalModel(block(300), shape=(50, 300)) + self.assertTrue(model.layers[-1].MultiHeadAttLayer.key_dense.use_bias, True) if __name__ == "__main__":
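
Usage sketch (editor's note, not part of the patch): the blocks registered in this diff are factories that are configured first, then given a filter count, and finally applied to a tensor, i.e. Block(**options)(filters)(x). The snippet below is a minimal, hypothetical example that wires the new ClassTokenLayer, LearnablePositionEmbsLayer, and TransformerEncoderLayer into a small classifier; the sequence length (50), feature width (96), head count (6), and class count (10) are illustrative values, not anything prescribed by this diff.

import tensorflow as tf
from deeptrack.models.embeddings import ClassTokenLayer, LearnablePositionEmbsLayer
from deeptrack.models.layers import TransformerEncoderLayer

# Token sequence input: 50 tokens with 96 features each (example sizes).
inputs = tf.keras.Input(shape=(50, 96))

# Prepend a learnable class token, then add learnable positional embeddings.
x = ClassTokenLayer()(96)(inputs)
x = LearnablePositionEmbsLayer()(96)(x)

# Two transformer encoder blocks; each call returns (features, attention_weights).
for _ in range(2):
    x, _ = TransformerEncoderLayer(number_of_heads=6)(96)(x)

# Classify from the class token at position 0.
cls = tf.keras.layers.Lambda(lambda t: t[:, 0])(x)
outputs = tf.keras.layers.Dense(10, activation="softmax")(cls)

model = tf.keras.Model(inputs, outputs)
model.summary()

The graph models added in deeptrack/models/gnns/models.py can likewise be instantiated directly. The call below mirrors MAGIK's default layer dimensions; output_type="edges" is one of the supported options ("nodes", "edges", or "graph") and is chosen here purely for illustration.

from deeptrack.models.gnns import MAGIK

magik = MAGIK(
    dense_layer_dimensions=(32, 64, 96),  # encoder/decoder widths
    base_layer_dimensions=(96, 96),       # two FGNN graph blocks
    number_of_node_features=3,
    number_of_edge_features=1,
    number_of_node_outputs=1,
    number_of_edge_outputs=1,
    output_type="edges",                  # return only the edge predictions
)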