In [1]:
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

%matplotlib inline

In [3]:
# Synthetic stand-in for a batch of latent codes: BATCH_SIZE samples x
# LATENT_DIM latent coordinates, drawn i.i.d. from a standard normal.
# A stateless op with a fixed seed keeps the batch identical across
# "Restart & Run All"; an unseeded tf.random.normal call produces a different
# batch (and therefore different false-neighbour counts) on every run.
BATCH_SIZE = 32
LATENT_DIM = 10
CODE_SEED = (0, 0)  # stateless RNG ops take a shape-[2] integer seed
code_batch = tf.random.stateless_normal((BATCH_SIZE, LATENT_DIM), seed=CODE_SEED)

In [7]:
# Number of latent coordinates. The batch size is not needed here; unpacking
# into exactly two names also fails loudly if code_batch is not rank 2.
_, n_latent = code_batch.shape

# Tolerances for the two false-neighbour tests further down:
#   rtol -- limit on the relative growth of a neighbour's distance when one
#           more coordinate is added (compared against scaled_dist);
#   atol -- limit on that new distance relative to the data spread all_ra.
# NOTE(review): these appear to be the Kennel et al. (1992) tolerances —
# confirm the source before tuning.
rtol, atol = 20.0, 2.0

In [41]:
# Pairwise squared distances for every nested prefix of latent coordinates.
# tri_mask is lower-triangular, so its row d keeps coordinates 0..d and zeroes
# the rest; broadcasting it against the (batch, n_latent) codes gives
# batch_masked[d] = the batch restricted to its first d+1 coordinates,
# i.e. batch_masked has shape (n_latent, batch, n_latent).
tri_mask = tf.linalg.band_part(tf.ones((n_latent, n_latent), tf.float32), -1, 0)
batch_masked = tri_mask[:, tf.newaxis, :] * code_batch[tf.newaxis, ...]

# Expansion ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b applied to all pairs at once.
# No square root is taken, so these are *squared* distances; floating-point
# cancellation can leave tiny negative values or a non-zero diagonal.
X_sq = tf.reduce_sum(tf.square(batch_masked), axis=2, keepdims=True)  # (n_latent, batch, 1)
pdist_vector = (
    X_sq
    + tf.linalg.matrix_transpose(X_sq)
    - 2 * tf.matmul(batch_masked, tf.linalg.matrix_transpose(batch_masked))
)

# all_dists[d, i, j]: squared distance between samples i and j using the
# first d+1 latent coordinates.
all_dists = pdist_vector

In [42]:
# Spread of the data in each coordinate prefix, used as the length scale in
# the "large jump" test below.
# For prefix d (first d+1 coordinates): take the per-coordinate variance over
# the batch (masked-out coordinates are constant zero and contribute nothing),
# sum them, divide by the d+1 active coordinates, and take the square root --
# i.e. the RMS per-coordinate standard deviation. Result shape: (n_latent,).
# The square root must wrap the whole ratio: applying it to 1/(d+1) alone
# (as this cell previously did) leaves all_ra in squared units.
# Explicit-axis reductions also avoid tf.squeeze collapsing the n_latent == 1 case.
all_ra = tf.sqrt(
    tf.reduce_sum(tf.math.reduce_variance(batch_masked, axis=1), axis=-1)
    / tf.range(1, 1 + n_latent, dtype=tf.float32)
)

In [43]:
# Nearest neighbours within each coordinate prefix, and how much their
# distance grows once the next coordinate is included.
k = 1  # neighbours examined per sample, not counting column 0 (normally the sample itself)

# Floor the squared distances at a tiny positive value so the ratio below never
# divides by zero or by a round-off negative. The former upper bound
# tf.reduce_max(all_dists) could not change any value, so tf.maximum is enough.
all_dists = tf.maximum(all_dists, 1e-14)

# Indices of the k+1 smallest distances per sample; top_k on the negated
# distances selects them without argsorting each full row.
_, inds = tf.math.top_k(-all_dists, int(k + 1))

# Squared distances to those neighbours in the prefix where they were found ...
neighbor_dists_d = tf.gather(all_dists, inds, batch_dims=-1)
# ... and to the same neighbours after adding the next coordinate (prefix d+1).
neighbor_new_dists = tf.gather(all_dists[1:], inds[:-1], batch_dims=-1)

# Relative growth sqrt((R_{d+1}^2 - R_d^2) / R_d^2) on squared distances.
# Adding a coordinate cannot shrink a squared distance mathematically, but the
# expanded distance formula can make the difference a tiny negative number,
# which tf.sqrt would turn into NaN; clamp it at zero instead.
scaled_dist = tf.sqrt(
    tf.maximum(neighbor_new_dists - neighbor_dists_d[:-1], 0.0)
    / neighbor_dists_d[:-1]
)

In [47]:
# False-nearest-neighbour decision: a neighbour is "false" if either test fires.
# Test 1: its distance grows by more than rtol (relative) when the next
# coordinate is added.
is_false_change = tf.greater(scaled_dist, rtol)

# Test 2: its distance after adding the coordinate exceeds atol times the
# spread all_ra of the current prefix (broadcast over batch and neighbours).
# NOTE(review): neighbor_new_dists is a *squared* distance (all_dists is built
# without a square root), while the classic Kennel et al. criterion compares an
# unsquared distance against A_tol * R_A -- confirm the intended units here.
is_large_jump = tf.greater(neighbor_new_dists, atol * all_ra[:-1, None, None])

is_false_neighbor = is_false_change | is_large_jump

# Despite the name, this is a 0/1 indicator per (prefix, sample, neighbour),
# shape (n_latent - 1, batch, k): column 0 (normally the sample itself) is dropped.
total_false_neighbors = tf.cast(is_false_neighbor[..., 1 : (k + 1)], tf.int32)