diff --git a/deeptrack/models/__init__.py b/deeptrack/models/__init__.py index a2070a5fe..bb756fcfe 100644 --- a/deeptrack/models/__init__.py +++ b/deeptrack/models/__init__.py @@ -7,6 +7,7 @@ from .gans import * from .gnns import * from .vaes import * +from .waes import * # from .mrcnn import * # from .yolov1 import * diff --git a/deeptrack/models/gans/gan.py b/deeptrack/models/gans/gan.py index ec1b5789c..646e1dbcd 100644 --- a/deeptrack/models/gans/gan.py +++ b/deeptrack/models/gans/gan.py @@ -3,15 +3,24 @@ layers = tf.keras.layers - -@as_KerasModel class GAN(tf.keras.Model): + """Generative Adversarial Network (GAN) model. + + Parameters: + discriminator: keras model, optional + The discriminator network. + generator: keras model, optional + The generator network. + latent_dim: int, optional + Dimension of the latent space for random vectors. + """ + def __init__(self, discriminator=None, generator=None, latent_dim=128): super(GAN, self).__init__() + # Initialize discriminator and generator, or use default if not provided if discriminator is None: discriminator = self.default_discriminator() - if generator is None: generator = self.default_generator() @@ -21,9 +30,13 @@ def __init__(self, discriminator=None, generator=None, latent_dim=128): def compile(self, d_optimizer, g_optimizer, loss_fn): super(GAN, self).compile() + + # Set optimizers and loss function for training self.d_optimizer = d_optimizer self.g_optimizer = g_optimizer self.loss_fn = loss_fn + + # Define metrics to track during training self.d_loss_metric = tf.keras.metrics.Mean(name="d_loss") self.g_loss_metric = tf.keras.metrics.Mean(name="g_loss") @@ -36,17 +49,18 @@ def train_step(self, real_images): batch_size = tf.shape(real_images)[0] random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim)) - # Decode them to fake images + # Generate fake images using the generator generated_images = self.generator(random_latent_vectors) - # Combine them with real images + # Combine real 
and fake images combined_images = tf.concat([generated_images, real_images], axis=0) - # Assemble labels discriminating real from fake images + # Create labels for real and fake images labels = tf.concat( [tf.ones((batch_size, 1)), tf.zeros((batch_size, 1))], axis=0 ) - # Add random noise to the labels - important trick! + + # Add random noise to labels to improve stability labels += 0.05 * tf.random.uniform(tf.shape(labels)) # Train the discriminator @@ -58,14 +72,13 @@ def train_step(self, real_images): zip(grads, self.discriminator.trainable_weights) ) - # Sample random points in the latent space + # Generate new random latent vectors random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim)) - # Assemble labels that say "all real images" + # Create labels indicating "all real images" for generator training misleading_labels = tf.zeros((batch_size, 1)) - # Train the generator (note that we should *not* update the weights - # of the discriminator)! + # Train the generator while keeping discriminator weights fixed with tf.GradientTape() as tape: predictions = self.discriminator(self.generator(random_latent_vectors)) g_loss = self.loss_fn(misleading_labels, predictions) @@ -75,12 +88,19 @@ def train_step(self, real_images): # Update metrics self.d_loss_metric.update_state(d_loss) self.g_loss_metric.update_state(g_loss) + + # Return updated loss metrics return { "d_loss": self.d_loss_metric.result(), "g_loss": self.g_loss_metric.result(), } + def call(self, inputs): + # Run generator + return self.generator(inputs) + def default_generator(self, latent_dim=128): + # Define the default generator architecture return tf.keras.Sequential( [ tf.keras.Input(shape=(latent_dim,)), @@ -98,6 +118,7 @@ def default_generator(self, latent_dim=128): ) def default_discriminator(self): + # Define the default discriminator architecture return tf.keras.Sequential( [ tf.keras.Input(shape=(64, 64, 3)), @@ -112,4 +133,4 @@ def default_discriminator(self): 
layers.Dense(1, activation="sigmoid"), ], name="discriminator", - ) + ) \ No newline at end of file diff --git a/deeptrack/models/vaes/vae.py b/deeptrack/models/vaes/vae.py index 25c6e72ae..2a5e10fd8 100644 --- a/deeptrack/models/vaes/vae.py +++ b/deeptrack/models/vaes/vae.py @@ -1,17 +1,27 @@ import tensorflow as tf from tensorflow.keras import layers - from ..utils import as_KerasModel -@as_KerasModel class VAE(tf.keras.Model): + """Variational Autoencoder (VAE) model. + + Parameters: + encoder: keras model, optional + The encoder network. + decoder: keras model, optional + The decoder network. + latent_dim: int, optional + Dimension of the latent space. + """ + def __init__(self, encoder=None, decoder=None, latent_dim=2, **kwargs): - super().__init__(**kwargs) + super(VAE, self).__init__(**kwargs) - # Dimensionality of the latent space + # Define encoder latent dimension self.latent_dim = latent_dim + # Initialize encoder and decoder, or use defaults if encoder is None: self.encoder = self.default_encoder() @@ -19,10 +29,11 @@ def __init__(self, encoder=None, decoder=None, latent_dim=2, **kwargs): self.decoder = self.default_decoder() def train_step(self, data): - data, _ = data + # Gradient tape for automatic differentiation with tf.GradientTape() as tape: + # Encode input data and sample from latent space. # The encoder outputs the mean and log of the variance of the # Gaussian distribution. The log of the variance is computed # instead of the variance for numerical stability. 
@@ -32,32 +43,25 @@ def train_step(self, data): epsilon = tf.random.normal(shape=tf.shape(z_mean)) z = z_mean + tf.exp(0.5 * z_log_var) * epsilon - # Reconstruct the input image + # Decode latent samples and compute reconstruction loss rdata = self.decoder(z) - - # Reconstruction loss rloss = self.loss(data, rdata) - # KL divergence loss - kl_loss = -0.5 * ( - 1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var) - ) + # Compute KL divergence loss + kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var)) kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1)) - # Total loss + # Compute total loss loss = rloss + kl_loss - # Compute gradients + # Compute gradients and update model weights grads = tape.gradient(loss, self.trainable_weights) + self.optimizer.apply_gradients(zip(grads, self.trainable_weights)) - # Update weights - self.optimizer.apply_gradients( - zip(grads, self.trainable_weights), - ) - - # Update metrics + # Update metrics for monitoring self.compiled_metrics.update_state(data, rdata) + # Return loss values for visualization return { "loss": loss, "reconstruction_loss": rloss, @@ -65,9 +69,11 @@ def train_step(self, data): } def call(self, inputs): + # Use encoder to obtain latent representation return self.encoder(inputs) def default_encoder(self): + # Define the default encoder architecture return tf.keras.Sequential( [ tf.keras.Input(shape=(28, 28, 1)), @@ -88,14 +94,13 @@ def default_encoder(self): layers.Flatten(), layers.Dense(16), layers.LeakyReLU(alpha=0.2), - layers.Dense( - self.latent_dim + self.latent_dim, name="z_mean_log_var" - ), + layers.Dense(self.latent_dim + self.latent_dim, name="z_mean_log_var"), ], name="encoder", ) def default_decoder(self): + # Define the default decoder architecture return tf.keras.Sequential( [ tf.keras.Input(shape=(self.latent_dim,)), diff --git a/deeptrack/models/waes/__init__.py b/deeptrack/models/waes/__init__.py new file mode 100644 index 000000000..95459837d --- /dev/null +++ 
b/deeptrack/models/waes/__init__.py @@ -0,0 +1 @@ +from .wae import * \ No newline at end of file diff --git a/deeptrack/models/waes/wae.py b/deeptrack/models/waes/wae.py new file mode 100644 index 000000000..263a00f1a --- /dev/null +++ b/deeptrack/models/waes/wae.py @@ -0,0 +1,312 @@ +import tensorflow as tf +from tensorflow.keras import layers +from ..utils import as_KerasModel + + +class WAE(tf.keras.Model): + """Wasserstein Autoencoder (WAE) regularized with either an adversarial latent discriminator (WAE-GAN) or the maximum mean discrepancy (WAE-MMD). + + Parameters: + regularizer: str, either 'mmd' or 'gan' (default is 'mmd') + encoder: keras model, optional + The encoder network. + decoder: keras model, optional + The decoder network. + discriminator: keras model, optional + The discriminator network. + latent_dim: int, optional + Dimension of the latent space. + lambda_: float, optional + Hyperparameter for regularization. + sigma_z: float, optional + Variance of the Gaussian latent prior; samples are drawn with standard deviation sqrt(sigma_z). + """ + + def __init__( + self, + regularizer="mmd", + encoder=None, + decoder=None, + discriminator=None, + latent_dim=8, + lambda_=10.0, + sigma_z=1.0, + **kwargs, + ): + super().__init__(**kwargs) + + # Store the regularizer type, latent dimension, and hyperparameters + self.regularizer = regularizer + self.latent_dim = latent_dim # For probabilistic encoder set as 2*latent_dim + self.lambda_ = lambda_ + self.sigma_z = sigma_z + + # Initialize encoder, decoder, and discriminator, or use defaults + if encoder is None: + encoder = self.default_encoder() + + if decoder is None: + decoder = self.default_decoder() + + if self.regularizer=="gan": + if discriminator is None: + discriminator = self.default_discriminator() + + self.encoder = encoder + self.decoder = decoder + self.discriminator = discriminator + + def compile( + self, enc_optimizer=None, dec_optimizer=None, disc_optimizer=None, loss_fn=None + ): + super().compile() + + # Fall back to default optimizers and loss function when none are provided + if 
enc_optimizer is None: + enc_optimizer = tf.keras.optimizers.Adam( + learning_rate=1e-3, beta_1=0.5, beta_2=0.999 + ) + + if dec_optimizer is None: + dec_optimizer = tf.keras.optimizers.Adam( + learning_rate=1e-3, beta_1=0.5, beta_2=0.999 + ) + + if self.regularizer=="gan": + if disc_optimizer is None: + disc_optimizer = tf.keras.optimizers.Adam( + learning_rate=5e-4, beta_1=0.5, beta_2=0.999 + ) + + if loss_fn is None: + loss_fn = tf.keras.losses.MeanSquaredError() + + self.enc_optim = enc_optimizer + self.dec_optim = dec_optimizer + self.disc_optim = disc_optimizer + self.loss_fn = loss_fn + + @tf.function + def mmd_penalty(self, pz, qz, batch_size): + # Estimator of the MMD with the IMQ kernel as in + # https://github.com/tolstikhin/wae/blob/master/wae.py#L233 and in + # https://github.com/w00zie/wae_mnist/blob/master/train_mmd.py + + # Here the property that the sum of positive definite kernels is + # still a p.d. kernel is used. Various kernels calculated at different + # scales are summed together in order to "simultaneously look at various + # scales" [https://github.com/tolstikhin/wae/issues/2]. 
+ + norms_pz = tf.reduce_sum(tf.square(pz), axis=1, keepdims=True) + dotprods_pz = tf.matmul(pz, pz, transpose_b=True) + distances_pz = norms_pz + tf.transpose(norms_pz) - 2.0 * dotprods_pz + + norms_qz = tf.reduce_sum(tf.square(qz), axis=1, keepdims=True) + dotprods_qz = tf.matmul(qz, qz, transpose_b=True) + distances_qz = norms_qz + tf.transpose(norms_qz) - 2.0 * dotprods_qz + + dotprods = tf.matmul(qz, pz, transpose_b=True) + distances = norms_qz + tf.transpose(norms_pz) - 2.0 * dotprods + + cbase = tf.constant(2.0 * self.latent_dim * self.sigma_z) + stat = tf.constant(0.0) + nf = tf.cast(batch_size, dtype=tf.float32) + + for scale in tf.constant([0.1, 0.2, 0.5, 1.0, 2.0, 5.0, 10.0]): + C = cbase * scale + res1 = C / (C + distances_qz) + res1 += C / (C + distances_pz) + res1 = tf.multiply(res1, 1.0 - tf.eye(batch_size)) + res1 = tf.reduce_sum(res1) / (nf * nf - nf) + res2 = C / (C + distances) + res2 = tf.reduce_sum(res2) * 2.0 / (nf * nf) + stat += res1 - res2 + return stat + + def train_step(self, data): + data, _ = data + + batch_size = tf.shape(data)[0] + + # Gradient tape for automatic differentiation + with tf.GradientTape(persistent=True) as tape: + # Encode input data + q_z = self.encoder(data) + # For a probabilistic encoder, sample from the latent space instead: + # z_mean, z_log_var = tf.split(self.encoder(data), 2, axis=1) + # epsilon = tf.random.normal(shape=tf.shape(z_mean)) + # q_z = z_mean + tf.exp(0.5 * z_log_var) * epsilon + + # Decode the latent codes to reconstruct the input + x_hat = self.decoder(q_z) + + # Compute reconstruction loss + rloss = self.loss_fn(data, x_hat) + + # Compute penalty for regularization + if self.regularizer=="gan": + d_qz = self.discriminator(q_z) + penalty = tf.keras.losses.binary_crossentropy( + tf.ones_like(d_qz), d_qz + ) + elif self.regularizer=="mmd": + p_z = tf.random.normal( + shape=(batch_size, self.latent_dim), + stddev=tf.sqrt(self.sigma_z), + ) + penalty = self.mmd_penalty(p_z, q_z, batch_size) + loss = rloss + 
tf.reduce_mean(self.lambda_ * penalty) + + # Compute gradients and update encoder and decoder weights + enc_grads = tape.gradient(loss, self.encoder.trainable_weights) + dec_grads = tape.gradient(loss, self.decoder.trainable_weights) + self.enc_optim.apply_gradients(zip(enc_grads, self.encoder.trainable_weights)) + self.dec_optim.apply_gradients(zip(dec_grads, self.decoder.trainable_weights)) + + # WAE-GAN only: train the discriminator on prior samples (real) versus encoded samples (fake) + if self.regularizer=="gan": + with tf.GradientTape() as tape: + p_z = tf.random.normal( + shape=(batch_size, self.latent_dim), + stddev=tf.sqrt(self.sigma_z), + ) + d_pz = self.discriminator(p_z) + + q_z = self.encoder(data) + # For a probabilistic encoder, sample from the latent space instead: + # z_mean, z_log_var = tf.split(self.encoder(data), 2, axis=1) + # epsilon = tf.random.normal(shape=tf.shape(z_mean)) + # q_z = z_mean + tf.exp(0.5 * z_log_var) * epsilon + d_qz = self.discriminator(q_z) + + # Compute losses for real and fake samples and discriminator loss + real_loss = tf.keras.losses.binary_crossentropy( + tf.ones_like(d_pz), d_pz + ) + fake_loss = tf.keras.losses.binary_crossentropy( + tf.zeros_like(d_qz), d_qz + ) + disc_loss = self.lambda_ * ( + tf.reduce_mean(real_loss) + tf.reduce_mean(fake_loss) + ) + + # Compute gradients and update discriminator weights + disc_grads = tape.gradient( + disc_loss, self.discriminator.trainable_weights + ) + self.disc_optim.apply_gradients( + zip(disc_grads, self.discriminator.trainable_weights) + ) + + # Update metrics for visualization + self.compiled_metrics.update_state(data, x_hat) + + # Return various loss values for monitoring + if self.regularizer=="gan": + return { + "loss": loss, + "reconstruction_loss": rloss, + "discriminator_loss": disc_loss, + } + elif self.regularizer=="mmd": + return { + "loss": loss, + "reconstruction_loss": rloss, + } + + def call(self, inputs): + # Use encoder to obtain latent representation + return self.encoder(inputs) + + def 
default_encoder(self): + # Define the default encoder architecture + return tf.keras.Sequential( + [ + tf.keras.Input(shape=(28, 28, 1)), + layers.Conv2D( + 128, + kernel_size=4, + strides=2, + padding="same", + ), + layers.BatchNormalization(), + layers.ReLU(), + layers.Conv2D( + 256, + kernel_size=4, + strides=2, + padding="same", + ), + layers.BatchNormalization(), + layers.ReLU(), + layers.Conv2D( + 512, + kernel_size=4, + strides=2, + padding="same", + ), + layers.BatchNormalization(), + layers.ReLU(), + layers.Conv2D( + 1024, + kernel_size=4, + strides=2, + padding="same", + ), + layers.BatchNormalization(), + layers.ReLU(), + layers.Flatten(), + layers.Dense(self.latent_dim), + ], + name="encoder", + ) + + def default_decoder(self): + # Define the default decoder architecture + return tf.keras.Sequential( + [ + tf.keras.Input(shape=(self.latent_dim,)), + layers.Dense(7 * 7 * 1024), + layers.Reshape((7, 7, 1024)), + layers.Conv2DTranspose( + 512, + kernel_size=4, + strides=2, + padding="same", + ), + layers.BatchNormalization(), + layers.ReLU(), + layers.Conv2DTranspose( + 256, + kernel_size=4, + strides=2, + padding="same", + ), + layers.BatchNormalization(), + layers.ReLU(), + layers.Conv2D( + 1, + kernel_size=4, + padding="same", + ), + ], + name="decoder", + ) + + def default_discriminator(self): + # Define the default discriminator architecture for WAE_GAN + return tf.keras.Sequential( + [ + tf.keras.Input(shape=(self.latent_dim,)), + layers.Dense(512), + layers.ReLU(), + layers.Dense(512), + layers.ReLU(), + layers.Dense(512), + layers.ReLU(), + layers.Dense(512), + layers.ReLU(), + layers.Dense(1, activation="sigmoid"), + ], + name="discriminator", + ) diff --git a/deeptrack/test/test_models.py b/deeptrack/test/test_models.py index 0bc545623..61d54a0f5 100644 --- a/deeptrack/test/test_models.py +++ b/deeptrack/test/test_models.py @@ -79,6 +79,58 @@ def test_UNet(self): model.predict(np.zeros((1, 64, 64, 1))) + def test_GAN(self): + model = models.GAN( 
+ discriminator=None, + generator=None, + latent_dim=128, + ) + model.compile( + d_optimizer=tf.keras.optimizers.Adam(), + g_optimizer=tf.keras.optimizers.Adam(), + loss_fn=tf.keras.losses.MeanAbsoluteError(), + ) + self.assertIsInstance(model.discriminator, tf.keras.Sequential) + self.assertIsInstance(model.generator, tf.keras.Sequential) + + prediction = model.predict(np.zeros((1, 128))) + self.assertEqual(prediction.shape, (1, 64, 64, 3)) + + def test_VAE(self): + model = models.VAE( + encoder=None, + decoder=None, + latent_dim=2, + ) + self.assertIsInstance(model.encoder, tf.keras.Sequential) + self.assertIsInstance(model.decoder, tf.keras.Sequential) + + pred_enc = model.encoder.predict(np.zeros((1, 28, 28, 1))) + self.assertEqual(pred_enc.shape, (1, 4)) + + pred_dec = model.decoder.predict(np.zeros((1, 2))) + self.assertEqual(pred_dec.shape, (1, 28, 28, 1)) + + def test_WAE(self): + model = models.WAE( + regularizer="mmd", + encoder=None, + decoder=None, + discriminator=None, + latent_dim=2, + lambda_=10.0, + sigma_z=1.0, + ) + model.compile() + self.assertIsInstance(model.encoder, tf.keras.Sequential) + self.assertIsInstance(model.decoder, tf.keras.Sequential) + + pred_enc = model.encoder.predict(np.zeros((1, 28, 28, 1))) + self.assertEqual(pred_enc.shape, (1, 2)) + + pred_dec = model.decoder.predict(pred_enc) + self.assertEqual(pred_dec.shape, (1, 28, 28, 1)) + def test_RNN(self): model = models.rnn( input_shape=(None, 64, 64, 1), @@ -135,9 +187,7 @@ def test_MAGIK(self): graph = ( tf.random.uniform((8, 10, 7)), # Node features tf.random.uniform((8, 50, 1)), # Edge features - tf.random.uniform( - (8, 50, 2), minval=0, maxval=10, dtype=tf.int32 - ), # Edges + tf.random.uniform((8, 50, 2), minval=0, maxval=10, dtype=tf.int32), # Edges tf.random.uniform((8, 50, 2)), # Edge dropouts ) model(graph) @@ -165,9 +215,7 @@ def test_CTMAGIK(self): graph = ( tf.random.uniform((8, 10, 7)), # Node features tf.random.uniform((8, 50, 1)), # Edge features - 
tf.random.uniform( - (8, 50, 2), minval=0, maxval=10, dtype=tf.int32 - ), # Edges + tf.random.uniform((8, 50, 2), minval=0, maxval=10, dtype=tf.int32), # Edges tf.random.uniform((8, 50, 2)), # Edge dropouts ) prediction = model(graph) @@ -199,9 +247,7 @@ def test_MAGIK_with_MaskedFGNN(self): graph = ( tf.random.uniform((8, 10, 7)), # Node features tf.random.uniform((8, 50, 1)), # Edge features - tf.random.uniform( - (8, 50, 2), minval=0, maxval=10, dtype=tf.int32 - ), # Edges + tf.random.uniform((8, 50, 2), minval=0, maxval=10, dtype=tf.int32), # Edges tf.random.uniform((8, 50, 2)), # Edge dropouts ) model(graph) @@ -229,9 +275,7 @@ def test_MPGNN(self): graph = ( tf.random.uniform((8, 10, 7)), # Node features tf.random.uniform((8, 50, 1)), # Edge features - tf.random.uniform( - (8, 50, 2), minval=0, maxval=10, dtype=tf.int32 - ), # Edges + tf.random.uniform((8, 50, 2), minval=0, maxval=10, dtype=tf.int32), # Edges tf.random.uniform((8, 50, 2)), # Edge dropouts ) model(graph) @@ -256,16 +300,12 @@ def test_MPGNN_readout(self): ) self.assertIsInstance(model, models.KerasModel) - self.assertIsInstance( - model.layers[-4], tf.keras.layers.GlobalAveragePooling1D - ) + self.assertIsInstance(model.layers[-4], tf.keras.layers.GlobalAveragePooling1D) graph = ( tf.random.uniform((8, 10, 7)), # Node features tf.random.uniform((8, 50, 1)), # Edge features - tf.random.uniform( - (8, 50, 2), minval=0, maxval=10, dtype=tf.int32 - ), # Edges + tf.random.uniform((8, 50, 2), minval=0, maxval=10, dtype=tf.int32), # Edges tf.random.uniform((8, 50, 2)), # Edge dropouts ) prediction = model(graph) @@ -296,13 +336,11 @@ def test_GRU_MPGNN(self): graph = ( tf.random.uniform((8, 10, 7)), # Node features tf.random.uniform((8, 50, 1)), # Edge features - tf.random.uniform( - (8, 50, 2), minval=0, maxval=10, dtype=tf.int32 - ), # Edges + tf.random.uniform((8, 50, 2), minval=0, maxval=10, dtype=tf.int32), # Edges tf.random.uniform((8, 50, 2)), # Edge dropouts ) model(graph) if __name__ == 
"__main__": - unittest.main() \ No newline at end of file + unittest.main()