Data is in HWC (height, width, channels) format (I think — unconfirmed).

In [9]:
import tensorflow as tf
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm_notebook

In [20]:
@tf.function
def shifted_meshgrid(grid, centers):
    """Return a coordinate grid re-centred on each of several points.

    Args:
        grid: either a (H, W, 2) tensor of coordinates, or a Python int ``size``.
            For an int, a size x size float32 grid of pixel coordinates is built
            with the default 'xy' meshgrid indexing, so channel 0 is the column
            index and channel 1 the row index.
        centers: tensor of 2-vectors, reshaped to (-1, 1, 1, 2) so that each
            center broadcasts against the whole grid.

    Returns:
        Tensor of shape (num_centers, H, W, 2) holding ``grid - center`` for
        every center.
    """
    if isinstance(grid, int):
        # Python ints reach a tf.function un-converted, so this branch is decided
        # at trace time (each distinct size causes one retrace).
        coords = tf.range(grid, dtype=tf.float32)
        grid = tf.stack(tf.meshgrid(coords, coords), axis=-1)
    centers = tf.reshape(centers, (-1, 1, 1, 2))
    return grid - centers

In [26]:
# Channel 0 of the grid shifted by the first center, (3, 3).
shifted_meshgrid(5, tf.constant([[3.0, 3.0], [1.0, 1.0]]))[0, :, :, 0]

<tf.Tensor: shape=(5, 5), dtype=float32, numpy=
array([[-3., -2., -1.,  0.,  1.],
       [-3., -2., -1.,  0.,  1.],
       [-3., -2., -1.,  0.,  1.],
       [-3., -2., -1.,  0.,  1.],
       [-3., -2., -1.,  0.,  1.]], dtype=float32)>

In [2]:
def gen_calculate_gravity(shape=(128, 128), eps=1e-8, gravity_func=(lambda x: tf.pow(x, -2))):
    """Build a compiled function giving the net pull of a mass heatmap on one point.

    Args:
        shape: (rows, cols) of the heatmap the returned function accepts.
            Rows index the y coordinate and columns the x coordinate.
        eps: lower bound applied to distances so g(r) and the direction
            vector never divide by zero.
        gravity_func: g(r), force magnitude per unit mass at distance r
            (default: inverse square).

    Returns:
        A tf.function ``calculate_gravity(heatmap, center)``. ``heatmap`` is a
        float32 tensor of shape ``shape`` (mass per cell) and ``center`` a float32
        (x, y) position. It returns a float32 (2,) vector: the sum over cells of
        heatmap * g(r) * (unit vector from center towards the cell), i.e. an
        attractive force.
    """
    rows, cols = shape
    # Cell-coordinate grid, built once and captured by the traced function.
    # With the default 'xy' meshgrid indexing grid[i, j] == (j, i) == (x, y)
    # and grid has shape (rows, cols, 2), matching the heatmap for any shape.
    # NOTE(review): cell (i, j) is represented by its corner (j, i), not its
    # centre (j + 0.5, i + 0.5) -- confirm this offset is intended.
    grid = tf.stack(
        tf.meshgrid(
            tf.range(cols, dtype=tf.float32),
            tf.range(rows, dtype=tf.float32),
        ),
        axis=-1,
    )

    @tf.function(input_signature=[
        tf.TensorSpec(shape=shape, dtype=tf.float32),
        tf.TensorSpec(shape=(2,), dtype=tf.float32),
    ])
    def calculate_gravity(heatmap, center):
        print("tracing")  # Python side effect: only runs when tf.function (re)traces
        distances = grid - center  # (rows, cols, 2) vectors from center to each cell
        r = tf.math.reduce_euclidean_norm(distances, axis=-1, keepdims=True)
        # Clamp r to at least eps so nothing divides by zero; same clamp as
        # relu(r - eps) + eps, without that subtract/add float round-trip.
        r = tf.maximum(r, eps)
        f = gravity_func(r) * tf.expand_dims(heatmap, -1)
        f = f * (distances / r)  # scale each magnitude by its unit direction
        return tf.reduce_sum(f, axis=[0, 1])

    # TODO: batched version (many centers per call)
    return calculate_gravity

In [3]:
def gen_time_step(force=(lambda x: 0), gravity=(lambda x, y: tf.zeros_like(y)), n=100, hm_res=128, dt=1e-3):
    """Build a compiled semi-implicit Euler step for ``n`` particles in 2-D.

    Args:
        force: external acceleration, called as force(x) with the (n, 2)
            positions; its result must broadcast to (n, 2).
        gravity: called as gravity(heatmap, position) once per particle, with
            the (hm_res, hm_res) particle-count heatmap and a (2,) position; must
            return a (2,) float32 acceleration. The default applies no gravity.
        n: number of particles (fixed by the input signature).
        hm_res: side length of the square heatmap, in unit-sized cells.
        dt: time step.

    Returns:
        tf.function ``time_step(x, v) -> (x_new, v_new)`` where x and v are
        float32 (n, 2) tensors of (x, y) positions and velocities.
    """
    @tf.function(input_signature=[tf.TensorSpec(shape=(n, 2)), tf.TensorSpec(shape=(n, 2))])
    def time_step(x, v):  # x: position, v: velocity
        # Integer cell per axis. Floor BEFORE combining axes (otherwise the
        # fractional part of y is scaled by hm_res and shifts the column), and
        # clip so out-of-box particles land on the nearest edge cell instead of
        # wrapping into a neighbouring row.
        cell = tf.cast(tf.clip_by_value(tf.floor(x), 0.0, float(hm_res - 1)), tf.int32)
        index = cell[:, 1] * hm_res + cell[:, 0]  # rows = y, cols = x (xy, not ij)
        counts = tf.math.bincount(index, minlength=hm_res * hm_res, maxlength=hm_res * hm_res)
        heatmap = tf.reshape(tf.cast(counts, tf.float32), (hm_res, hm_res))

        # TODO: batch over particles when n is small enough to fit in memory
        g = tf.map_fn(lambda r: gravity(heatmap, r), x)

        # Update velocity first, then move with the new velocity.
        v = v + (g + force(x)) * dt
        x = x + v * dt
        return x, v
    return time_step

In [11]:
# hm_res (heatmap side length) must match the gravity heatmap shape.
time_step = gen_time_step(
    gravity=gen_calculate_gravity(shape=(128, 128)),
    n=10000,
    hm_res=128,
)

In [12]:
%%time
x, v = tf.random.uniform(minval=0, maxval=128, shape=(10000, 2)), tf.random.uniform(minval=0, maxval=1, shape=(10000, 2))
for i in tqdm_notebook(range(600)):
    x, v = time_step(x, v)
print("done")

  0%|          | 0/600 [00:00<?, ?it/s]

tracing


KeyboardInterrupt: 