Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions deeptrack/layers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# For backwards compatibility
from .models.layers import *
from .models.embeddings import *
from .models.gnns.layers import *
148 changes: 148 additions & 0 deletions deeptrack/models/embeddings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
import tensorflow as tf

from .utils import single_layer_call
from .layers import register


class ClassToken(tf.keras.layers.Layer):
    """Prepend a trainable class token to the inputs.

    The token is a zero-initialised variable of shape
    ``(1, 1, hidden_size)``, where ``hidden_size`` is the size of the last
    input axis. On each call it is broadcast over the batch axis and
    concatenated in front of the inputs along axis 1.
    """

    def build(self, input_shape):
        # The token width follows the last axis of the incoming tensor.
        self.hidden_size = input_shape[-1]
        zero_token = tf.zeros_initializer()(
            shape=(1, 1, self.hidden_size), dtype="float32"
        )
        self.cls = tf.Variable(
            initial_value=zero_token,
            trainable=True,
            name="cls",
        )

    def call(self, inputs):
        # One copy of the token per sample in the batch.
        n_samples = tf.shape(inputs)[0]
        token = tf.broadcast_to(self.cls, [n_samples, 1, self.hidden_size])
        # Match the input dtype so the concatenation is type-consistent.
        token = tf.cast(token, dtype=inputs.dtype)
        return tf.concat([token, inputs], 1)


@register("ClassToken")
def ClassTokenLayer(activation=None, normalization=None, norm_kwargs=None, **kwargs):
    """ClassToken layer factory that prepends a class token to the input.

    Can optionally perform normalization or some activation function.

    Accepts arguments of keras.layers.Layer.

    Parameters
    ----------
    activation : str or activation function or layer
        Activation function of the layer. See keras docs for accepted strings.
    normalization : str or normalization function or layer
        Normalization function of the layer. See keras and tfa docs for accepted strings.
    norm_kwargs : dict, optional
        Arguments for the normalization function. Defaults to an empty dict.
    **kwargs
        Other keras.layers.Layer arguments

    Returns
    -------
    callable
        ``Layer(filters, **kwargs_inner)``. Calling it creates one
        ``ClassToken`` layer and returns a function that applies that layer
        to a tensor through ``single_layer_call``.
    """
    # Use a fresh dict per factory call rather than a mutable default that
    # would be shared by every call of ClassTokenLayer.
    norm_kwargs = {} if norm_kwargs is None else norm_kwargs

    def Layer(filters, **kwargs_inner):
        # `filters` is accepted but not used by this layer.
        # Factory-level kwargs take precedence over the per-call ones.
        kwargs_inner.update(kwargs)
        layer = ClassToken(**kwargs_inner)
        return lambda x: single_layer_call(
            x, layer, activation, normalization, norm_kwargs
        )

    return Layer


class LearnablePositionEmbs(tf.keras.layers.Layer):
    """Adds or concatenates learnable positional embeddings to the inputs.

    Parameters
    ----------
    initializer : str or callable or tf.keras.initializers.Initializer
        Initializer function for the embeddings. Strings are resolved with
        ``tf.keras.initializers.get``. See tf.keras.initializers.Initializer
        for accepted functions.
    concat : bool
        Whether to concatenate the positional embeddings to the inputs along
        the last axis. If False, adds the positional embeddings to the inputs.
    kwargs: dict
        Other arguments for the keras.layers.Layer
    """

    def __init__(
        self,
        initializer=tf.keras.initializers.RandomNormal(stddev=0.06),
        concat=False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.concat = concat

        # Resolve documented string identifiers (e.g. "zeros") to an
        # initializer object before validating.
        if isinstance(initializer, str):
            initializer = tf.keras.initializers.get(initializer)

        # `callable(...)` must be called; comparing with `is callable` is
        # always False and silently rejected plain callables.
        assert callable(initializer) or isinstance(
            initializer, tf.keras.initializers.Initializer
        ), "initializer must be callable or a tf.keras.initializers.Initializer"
        self.initializer = initializer

    def build(self, input_shape):
        assert (
            len(input_shape) == 3
        ), f"Number of dimensions should be 3, got {len(input_shape)}"
        # One embedding per position and feature, shared across the batch:
        # shape (1, input_shape[-2], input_shape[-1]).
        self.pos_embedding = tf.Variable(
            name="pos_embedding",
            initial_value=self.initializer(shape=(1, *(input_shape[-2:]))),
            dtype="float32",
            trainable=True,
        )

    def call(self, inputs):
        pos_embedding = tf.cast(self.pos_embedding, dtype=inputs.dtype)
        if self.concat:
            # tf.concat requires equal sizes on every axis except the
            # concatenation axis, so repeat the embedding over the batch.
            batch_size = tf.shape(inputs)[0]
            pos_embedding = tf.tile(pos_embedding, [batch_size, 1, 1])
            return tf.concat([inputs, pos_embedding], axis=-1)
        # Addition broadcasts the leading axis of size 1 over the batch.
        return inputs + pos_embedding


@register("LearnablePositionEmbs")
def LearnablePositionEmbsLayer(
    initializer=tf.keras.initializers.RandomNormal(stddev=0.06),
    concat=False,
    activation=None,
    normalization=None,
    norm_kwargs=None,
    **kwargs,
):
    """Adds or concatenates positional embeddings to the inputs.

    Can optionally perform normalization or some activation function.

    Accepts arguments of keras.layers.Layer.

    Parameters
    ----------
    initializer : str or tf.keras.initializers.Initializer
        Initializer function for the embeddings. See tf.keras.initializers.Initializer for accepted functions.
    concat : bool
        Whether to concatenate the positional embeddings to the inputs. If False,
        adds the positional embeddings to the inputs.
    activation : str or activation function or layer
        Activation function of the layer. See keras docs for accepted strings.
    normalization : str or normalization function or layer
        Normalization function of the layer. See keras and tfa docs for accepted strings.
    norm_kwargs : dict, optional
        Arguments for the normalization function. Defaults to an empty dict.
    **kwargs
        Other arguments for the keras.layers.Layer

    Returns
    -------
    callable
        ``Layer(filters, **kwargs_inner)``. Calling it creates one
        ``LearnablePositionEmbs`` layer and returns a function that applies
        that layer to a tensor through ``single_layer_call``.
    """
    # Use a fresh dict per factory call rather than a mutable default that
    # would be shared by every call of LearnablePositionEmbsLayer.
    norm_kwargs = {} if norm_kwargs is None else norm_kwargs

    def Layer(filters, **kwargs_inner):
        # `filters` is accepted but not used by this layer.
        # Factory-level kwargs take precedence over the per-call ones.
        kwargs_inner.update(kwargs)
        layer = LearnablePositionEmbs(
            initializer=initializer, concat=concat, **kwargs_inner
        )
        return lambda x: single_layer_call(
            x, layer, activation, normalization, norm_kwargs
        )

    return Layer
2 changes: 1 addition & 1 deletion deeptrack/models/gans/cgan.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import tensorflow as tf
from .utils import as_KerasModel
from ..utils import as_KerasModel

layers = tf.keras.layers

Expand Down
2 changes: 1 addition & 1 deletion deeptrack/models/gans/gan.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import tensorflow as tf
from .utils import as_KerasModel
from ..utils import as_KerasModel

layers = tf.keras.layers

Expand Down
28 changes: 12 additions & 16 deletions deeptrack/models/gans/pcgan.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import tensorflow as tf
from .utils import as_KerasModel
from ..utils import as_KerasModel

layers = tf.keras.layers

Expand All @@ -18,20 +18,20 @@ class PCGAN(tf.keras.Model):
discriminator_loss: str or keras loss function
The loss function of the discriminator network
discriminator_optimizer: str or keras optimizer
The optimizer of the discriminator network
The optimizer of the discriminator network
discriminator_metrics: list, optional
List of metrics to be evaluated by the discriminator
model during training and testing
model during training and testing
assemble_loss: list of str or keras loss functions
List of loss functions to be evaluated on each output
of the assemble model (stacked generator and discriminator),
such as `assemble_loss = ["mse", "mse", "mae"]` for
the prediction of the discriminator, the predicted
perceptual features, and the generated image, respectively
perceptual features, and the generated image, respectively
assemble_optimizer: str or keras optimizer
The optimizer of the assemble network
The optimizer of the assemble network
assemble_loss_weights: list or dict, optional
List or dictionary specifying scalar coefficients (floats)
List or dictionary specifying scalar coefficients (floats)
to weight the loss contributions of the assemble model outputs
perceptual_discriminator: str or keras model
Name of the perceptual discriminator. Select the name of this network
Expand All @@ -43,8 +43,8 @@ class PCGAN(tf.keras.Model):
ImageNet weights, or provide the path to the weights file to be loaded.
Only to be specified if `perceptual_discriminator` is a keras application
model.
metrics: list, optional
List of metrics to be evaluated on the generated images during
metrics: list, optional
List of metrics to be evaluated on the generated images during
training and testing
"""

Expand Down Expand Up @@ -77,9 +77,7 @@ def __init__(
if isinstance(perceptual_discriminator, str):
self.perceptual_discriminator = tf.keras.Sequential(
[
layers.Lambda(
lambda img: layers.Concatenate(axis=-1)([img] * 3)
),
layers.Lambda(lambda img: layers.Concatenate(axis=-1)([img] * 3)),
getattr(tf.keras.applications, perceptual_discriminator)(
include_top=False,
weights=perceptual_discriminator_weights,
Expand All @@ -93,8 +91,8 @@ def __init__(

else:
raise AttributeError(
'Invalid model format. perceptual_discriminator must be either a string '
'indicating the name of the pre-trained model, or a keras model.'
"Invalid model format. perceptual_discriminator must be either a string "
"indicating the name of the pre-trained model, or a keras model."
)

self.perceptual_discriminator.trainable = False
Expand Down Expand Up @@ -164,9 +162,7 @@ def train_step(self, data):
with tf.GradientTape() as tape:
assemble_output = self.assemble(batch_x)

generated_image_copies = [assemble_output[2]] * (
self.num_losses - 1
)
generated_image_copies = [assemble_output[2]] * (self.num_losses - 1)

batch_y_copies = [batch_y] * (self.num_losses - 1)

Expand Down
1 change: 1 addition & 0 deletions deeptrack/models/gnns/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .models import *
Loading