deeptrack/backend/pint_definition.py (2 additions, 1 deletion)
@@ -306,7 +306,8 @@
 reciprocal_centimeter = 1 / cm = cm_1 = kayser
 
 # Velocity
-[velocity] = [length] / [time] = [speed]
+[velocity] = [length] / [time]
+[speed] = [velocity]
 knot = nautical_mile / hour = kt = knot_international = international_knot
 mile_per_hour = mile / hour = mph = MPH
 kilometer_per_hour = kilometer / hour = kph = KPH
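The hunk above splits the compound definition so that [speed] is declared as an alias of [velocity] on its own line. A minimal sketch of how a speed unit resolves, using the public pint API with its default registry (which ships definitions equivalent to this file; the registry deeptrack builds from its bundled copy is not shown in this diff):

import pint

ureg = pint.UnitRegistry()

speed = 15 * ureg.knot
print(speed.dimensionality)               # [length] / [time]
print(speed.to(ureg.kilometer_per_hour))  # ~27.78 kilometer_per_hour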
deeptrack/models/convolutional.py (16 additions, 7 deletions)
@@ -560,12 +560,16 @@ def __init__(
         )(layer, **transformer_input_kwargs)
 
         # Extract global representation
-        cls_rep = layers.Lambda(lambda x: x[:, 0], name="RetrieveClassToken")(layer)
+        cls_rep = layers.Lambda(lambda x: x[:, 0], name="RetrieveClassToken")(
+            layer
+        )
 
         # Process cls features
         cls_layer = cls_rep
         if cls_layer_dimension is not None:
-            cls_layer = dense_block(cls_layer_dimension, name="cls_mlp")(cls_layer)
+            cls_layer = dense_block(cls_layer_dimension, name="cls_mlp")(
+                cls_layer
+            )
 
         cls_output = layers.Dense(
             number_of_cls_outputs,
@@ -686,15 +690,20 @@ def __init__(
             norm_kwargs={"epsilon": 1e-6},
         ),
         positional_embedding_block=Identity(),
+        use_transformer_mask=False,
         **kwargs,
     ):
 
         dense_block = as_block(dense_block)
 
-        transformer_input, transformer_mask = (
-            layers.Input(shape=(None, number_of_node_features)),
-            layers.Input(shape=(None, 2), dtype="int32"),
-        )
+        transformer_input = layers.Input(shape=(None, number_of_node_features))
+        Inputs = [transformer_input]
+
+        if use_transformer_mask:
+            transformer_mask = layers.Input(shape=(None, 2), dtype="int32")
+            Inputs.append(transformer_mask)
+        else:
+            transformer_mask = None
 
         layer = transformer_input
         # Encoder for input features
@@ -735,7 +744,7 @@
         )(layer)
 
         model = models.Model(
-            [transformer_input, transformer_mask],
+            Inputs,
            output_layer,
         )
 
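The hunks above make the mask input optional: the Keras Model is now built from an Inputs list that contains the mask tensor only when use_transformer_mask=True. A minimal, self-contained sketch of the same pattern in plain Keras (build_graph_model and its layer sizes are illustrative stand-ins, not the deeptrack model):

from tensorflow.keras import layers, models


def build_graph_model(number_of_node_features=7, use_transformer_mask=False):
    # The node-feature input is always present.
    transformer_input = layers.Input(shape=(None, number_of_node_features))
    inputs = [transformer_input]

    if use_transformer_mask:
        # The mask input is declared (and expected at call time) only on request.
        transformer_mask = layers.Input(shape=(None, 2), dtype="int32")
        inputs.append(transformer_mask)

    # Stand-in body; the real model applies dense and transformer blocks here.
    layer = layers.Dense(32, activation="relu")(transformer_input)
    output_layer = layers.Dense(1)(layer)

    return models.Model(inputs, output_layer)


model = build_graph_model(use_transformer_mask=False)
model.summary()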
deeptrack/models/layers.py (42 additions, 16 deletions)
@@ -730,12 +730,16 @@ class TransformerEncoder(tf.keras.layers.Layer):
         Activation function of the layer. See keras docs for accepted strings.
     normalization : str or normalization function or layer
         Normalization function of the layer. See keras and tfa docs for accepted strings.
-    use_gates : bool, optional
+    use_gates : bool, optional [Deprecated]
         Whether to use gated self-attention layers as update layer. Defaults to False.
     use_bias: bool, optional
         Whether to use bias in the dense layers of the attention layers. Defaults to False.
     norm_kwargs : dict
         Arguments for the normalization function.
+    multi_head_attention_layer : tf.keras.layers.Layer
+        Layer to use for the multi-head attention. Defaults to dt.layers.MultiHeadSelfAttention.
+    multi_head_attention_kwargs : dict
+        Arguments for the multi-head attention layer.
     kwargs : dict
         Additional arguments.
     """
@@ -750,6 +754,10 @@ def __init__(
         use_gates=False,
         use_bias=False,
         norm_kwargs={},
+        multi_head_attention_layer: layers.Layer = None,
+        multi_head_attention_kwargs={},
+        fwd_mlp_layer: layers.Layer = None,
+        fwd_mlp_kwargs={},
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -764,16 +772,37 @@
 
         self.normalization = normalization
 
-        self.MultiHeadAttLayer = (
-            MultiHeadGatedSelfAttention
-            if self.use_gates
-            else MultiHeadSelfAttention
-        )(
-            number_of_heads=self.number_of_heads,
-            use_bias=self.use_bias,
-            return_attention_weights=True,
-            name="MultiHeadAttLayer",
-        )
+        if multi_head_attention_layer is None:
+            # Raise deprecation warning
+            warnings.warn(
+                "The use_gates argument is deprecated and will be removed in a future version. "
+                "Please use the multi_head_attention_layer argument instead.",
+                DeprecationWarning,
+            )
+
+            self.MultiHeadAttLayer = (
+                MultiHeadGatedSelfAttention
+                if self.use_gates
+                else MultiHeadSelfAttention
+            )(
+                number_of_heads=self.number_of_heads,
+                use_bias=self.use_bias,
+                return_attention_weights=True,
+                name="MultiHeadAttLayer",
+            )
+        else:
+            self.MultiHeadAttLayer = multi_head_attention_layer(
+                **multi_head_attention_kwargs
+            )
+
+        if fwd_mlp_layer is None:
+            self.FwdMlpLayer = layers.Dense(
+                self.fwd_mlp_dim,
+                name=f"{self.name}/Dense_0",
+            )
+        else:
+            self.FwdMlpLayer = fwd_mlp_layer(**fwd_mlp_kwargs)
+
         self.norm_0, self.norm_1 = (
             as_normalization(normalization)(**norm_kwargs),
             as_normalization(normalization)(**norm_kwargs),
@@ -783,10 +812,7 @@ def build(self, input_shape):
     def build(self, input_shape):
         self.feed_forward_layer = tf.keras.Sequential(
             [
-                layers.Dense(
-                    self.fwd_mlp_dim,
-                    name=f"{self.name}/Dense_0",
-                ),
+                self.FwdMlpLayer,
                 as_activation(self.activation),
                 layers.Dropout(self.dropout),
                 layers.Dense(input_shape[-1], name=f"{self.name}/Dense_1"),
@@ -827,7 +853,7 @@ def TransformerEncoderLayer(
         Activation function of the layer. See keras docs for accepted strings.
     normalization : str or normalization function or layer
         Normalization function of the layer. See keras and tfa docs for accepted strings.
-    use_gates : bool, optional
+    use_gates : bool, optional [Deprecated]
         Whether to use gated self-attention layers as update layer. Defaults to False.
     use_bias: bool, optional
         Whether to use bias in the dense layers of the attention layers. Defaults to True.
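With these changes, TransformerEncoder accepts the attention and feed-forward layers as injectable arguments instead of the deprecated use_gates flag; when multi_head_attention_layer is omitted, the old code path still runs but now emits a DeprecationWarning. A hedged usage sketch, based only on the names visible in this diff (the import path and any constructor arguments beyond those shown are assumptions):

from deeptrack.models.layers import (
    TransformerEncoder,
    MultiHeadGatedSelfAttention,
)

# Old style: still accepted, but the constructor now warns whenever
# multi_head_attention_layer is left as None.
encoder_old = TransformerEncoder(use_gates=True)

# New style: pass the attention layer class plus its constructor kwargs.
# return_attention_weights=True mirrors the value the encoder set internally
# before this change.
encoder_new = TransformerEncoder(
    multi_head_attention_layer=MultiHeadGatedSelfAttention,
    multi_head_attention_kwargs={
        "number_of_heads": 6,
        "use_bias": False,
        "return_attention_weights": True,
        "name": "MultiHeadAttLayer",
    },
)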
requirements.txt (1 addition, 1 deletion)
@@ -10,6 +10,6 @@ tensorflow-probability
 tensorflow-datasets
 pydeepimagej
 more_itertools
-pint
+pint<0.20
 pandas
 tqdm
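The pin above constrains pint to pre-0.20 releases; the diff does not state the reason. A small standard-library check that an environment satisfies the declared bound:

from importlib.metadata import version

pint_version = version("pint")
major, minor = (int(part) for part in pint_version.split(".")[:2])
assert (major, minor) < (0, 20), f"expected pint<0.20, found {pint_version}"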