Skip to content
1 change: 1 addition & 0 deletions deeptrack/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from .gans import *
from .gnns import *
from .vaes import *
from .waes import *

# from .mrcnn import *
# from .yolov1 import *
Expand Down
45 changes: 33 additions & 12 deletions deeptrack/models/gans/gan.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,24 @@

layers = tf.keras.layers


@as_KerasModel
class GAN(tf.keras.Model):
"""Generative Adversarial Network (GAN) model.

Parameters:
discriminator: keras model, optional
The discriminator network.
generator: keras model, optional
The generator network.
latent_dim: int, optional
Dimension of the latent space for random vectors.
"""

def __init__(self, discriminator=None, generator=None, latent_dim=128):
super(GAN, self).__init__()

# Initialize discriminator and generator, or use default if not provided
if discriminator is None:
discriminator = self.default_discriminator()

if generator is None:
generator = self.default_generator()

Expand All @@ -21,9 +30,13 @@ def __init__(self, discriminator=None, generator=None, latent_dim=128):

def compile(self, d_optimizer, g_optimizer, loss_fn):
    """Configure the GAN for adversarial training.

    Parameters:
    d_optimizer: keras optimizer
        Optimizer applied to the discriminator's weights.
    g_optimizer: keras optimizer
        Optimizer applied to the generator's weights.
    loss_fn: callable
        Loss function evaluated on discriminator predictions.
    """
    super(GAN, self).compile()

    # Keep the training components on the model instance so that
    # train_step can reach them.
    self.d_optimizer = d_optimizer
    self.g_optimizer = g_optimizer
    self.loss_fn = loss_fn

    # Running means of the two adversarial losses, reported per epoch.
    self.d_loss_metric = tf.keras.metrics.Mean(name="d_loss")
    self.g_loss_metric = tf.keras.metrics.Mean(name="g_loss")

Expand All @@ -36,17 +49,18 @@ def train_step(self, real_images):
batch_size = tf.shape(real_images)[0]
random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))

# Decode them to fake images
# Generate fake images using the generator
generated_images = self.generator(random_latent_vectors)

# Combine them with real images
# Combine real and fake images
combined_images = tf.concat([generated_images, real_images], axis=0)

# Assemble labels discriminating real from fake images
# Create labels for real and fake images
labels = tf.concat(
[tf.ones((batch_size, 1)), tf.zeros((batch_size, 1))], axis=0
)
# Add random noise to the labels - important trick!

# Add random noise to labels to improve stability
labels += 0.05 * tf.random.uniform(tf.shape(labels))

# Train the discriminator
Expand All @@ -58,14 +72,13 @@ def train_step(self, real_images):
zip(grads, self.discriminator.trainable_weights)
)

# Sample random points in the latent space
# Generate new random latent vectors
random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))

# Assemble labels that say "all real images"
# Create labels indicating "all real images" for generator training
misleading_labels = tf.zeros((batch_size, 1))

# Train the generator (note that we should *not* update the weights
# of the discriminator)!
# Train the generator while keeping discriminator weights fixed
with tf.GradientTape() as tape:
predictions = self.discriminator(self.generator(random_latent_vectors))
g_loss = self.loss_fn(misleading_labels, predictions)
Expand All @@ -75,12 +88,19 @@ def train_step(self, real_images):
# Update metrics
self.d_loss_metric.update_state(d_loss)
self.g_loss_metric.update_state(g_loss)

# Return updated loss metrics
return {
"d_loss": self.d_loss_metric.result(),
"g_loss": self.g_loss_metric.result(),
}

def call(self, inputs):
    """Forward pass: map latent vectors to generated samples.

    Parameters:
    inputs: tensor
        Batch of latent vectors fed to the generator.
    """
    generated = self.generator(inputs)
    return generated

def default_generator(self, latent_dim=128):
# Define the default generator architecture
return tf.keras.Sequential(
[
tf.keras.Input(shape=(latent_dim,)),
Expand All @@ -98,6 +118,7 @@ def default_generator(self, latent_dim=128):
)

def default_discriminator(self):
# Define the default discriminator architecture
return tf.keras.Sequential(
[
tf.keras.Input(shape=(64, 64, 3)),
Expand All @@ -112,4 +133,4 @@ def default_discriminator(self):
layers.Dense(1, activation="sigmoid"),
],
name="discriminator",
)
)
51 changes: 28 additions & 23 deletions deeptrack/models/vaes/vae.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,39 @@
import tensorflow as tf
from tensorflow.keras import layers

from ..utils import as_KerasModel


@as_KerasModel
class VAE(tf.keras.Model):
"""Variational Autoencoder (VAE) model.

Parameters:
encoder: keras model, optional
The encoder network.
decoder: keras model, optional
The decoder network.
latent_dim: int, optional
Dimension of the latent space.
"""

def __init__(self, encoder=None, decoder=None, latent_dim=2, **kwargs):
    """Build the VAE from optional encoder and decoder networks.

    Parameters:
    encoder: keras model, optional
        Encoder network; the default architecture is used when omitted.
    decoder: keras model, optional
        Decoder network; the default architecture is used when omitted.
    latent_dim: int, optional
        Dimension of the latent space (read by the default networks).
    """
    super(VAE, self).__init__(**kwargs)

    # Must be set before building the defaults, which read self.latent_dim.
    self.latent_dim = latent_dim

    # Bug fix: the original only assigned self.encoder/self.decoder inside
    # the `is None` branch, so a user-supplied network was silently dropped
    # and the attribute was left unset. Always assign, falling back to the
    # defaults only when no network is provided.
    self.encoder = encoder if encoder is not None else self.default_encoder()
    self.decoder = decoder if decoder is not None else self.default_decoder()

def train_step(self, data):

data, _ = data

# Gradient tape for automatic differentiation
with tf.GradientTape() as tape:
# Encode input data and sample from latent space.
# The encoder outputs the mean and log of the variance of the
# Gaussian distribution. The log of the variance is computed
# instead of the variance for numerical stability.
Expand All @@ -32,42 +43,37 @@ def train_step(self, data):
epsilon = tf.random.normal(shape=tf.shape(z_mean))
z = z_mean + tf.exp(0.5 * z_log_var) * epsilon

# Reconstruct the input image
# Decode latent samples and compute reconstruction loss
rdata = self.decoder(z)

# Reconstruction loss
rloss = self.loss(data, rdata)

# KL divergence loss
kl_loss = -0.5 * (
1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var)
)
# Compute KL divergence loss
kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))

# Total loss
# Compute total loss
loss = rloss + kl_loss

# Compute gradients
# Compute gradients and update model weights
grads = tape.gradient(loss, self.trainable_weights)
self.optimizer.apply_gradients(zip(grads, self.trainable_weights))

# Update weights
self.optimizer.apply_gradients(
zip(grads, self.trainable_weights),
)

# Update metrics
# Update metrics for monitoring
self.compiled_metrics.update_state(data, rdata)

# Return loss values for visualization
return {
"loss": loss,
"reconstruction_loss": rloss,
"kl_loss": kl_loss,
}

def call(self, inputs):
    """Forward pass: encode the inputs into their latent representation.

    Parameters:
    inputs: tensor
        Batch of input samples passed through the encoder.
    """
    latent = self.encoder(inputs)
    return latent

def default_encoder(self):
# Define the default encoder architecture
return tf.keras.Sequential(
[
tf.keras.Input(shape=(28, 28, 1)),
Expand All @@ -88,14 +94,13 @@ def default_encoder(self):
layers.Flatten(),
layers.Dense(16),
layers.LeakyReLU(alpha=0.2),
layers.Dense(
self.latent_dim + self.latent_dim, name="z_mean_log_var"
),
layers.Dense(self.latent_dim + self.latent_dim, name="z_mean_log_var"),
],
name="encoder",
)

def default_decoder(self):
# Define the default decoder architecture
return tf.keras.Sequential(
[
tf.keras.Input(shape=(self.latent_dim,)),
Expand Down
1 change: 1 addition & 0 deletions deeptrack/models/waes/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .wae import *
Loading