# First Tests and Ansatz

In [199]:
## Imports
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
import pandas as pd
import matplotlib.pyplot as plt

In [200]:
## Settings
tf.random.set_seed(42)
np.random.seed(42)

## 01 Running an easy 2 Body Problem in a gradient tape

In [201]:
@tf.function
def inverse_pairwise_distances(X) -> tf.Tensor:
    """
    Returns a matrix with the inverse paired euclidean distances of the instances (rows) in X. The diagonals get imputed with zeros.

    Parameters
    X : tf.Tensor of shape (N, M)
        Tensor with instances to compue the pairwise distances between each other. Gets casted to float32.
   
    Returns
    tf.Tensor of shape (N, N)
        Pairwise distances in the upper or lower triangular part.
    """
    X = tf.cast(X, tf.float32)
    r = tf.reduce_sum(X * X, 1)
    r = tf.reshape(r, [-1, 1])
    D = r - 2.*tf.matmul(X, tf.transpose(X)) + tf.transpose(r)
    D = tf.linalg.set_diag(D, tf.repeat(np.inf, X.shape[0]))
    D = tf.linalg.band_part(1./D, 0, -1)
    return D

@tf.function
def pairwise_mass_product(M) -> tf.Tensor:
    Msq = tf.expand_dims(M, axis=0) * tf.expand_dims(M, axis=1)
    return Msq

@tf.function
def compute_potential(X: tf.Tensor, M: tf.Tensor) -> tf.Tensor:
    D = inverse_pairwise_distances(X)
    Msq = pairwise_mass_product(M)
    return tf.reduce_sum(D*Msq)
    

In [202]:
X0 = tf.Variable([[0., 0., 0.], [1., 1., 1.], [2., 2., 2.]])
M0 = tf.constant([3., 2., 1.])

In [203]:
inverse_pairwise_distances(X0)

<tf.Tensor: shape=(3, 3), dtype=float32, numpy=
array([[0.        , 0.33333334, 0.08333334],
       [0.        , 0.        , 0.33333334],
       [0.        , 0.        , 0.        ]], dtype=float32)>

In [204]:
compute_potential(X0, M0)

<tf.Tensor: shape=(), dtype=float32, numpy=2.9166667>

In [205]:
with tf.GradientTape() as tape:
    tape.watch(X0)
    pot_energy = compute_potential(X0, M0)
grad = tape.gradient(pot_energy, X0)
print(grad)

tf.Tensor(
[[ 1.4166667  1.4166667  1.4166667]
 [-0.8888889 -0.8888889 -0.8888889]
 [-0.5277778 -0.5277778 -0.5277778]], shape=(3, 3), dtype=float32)


In [206]:
pot_energy

<tf.Tensor: shape=(), dtype=float32, numpy=2.9166667>