In [1]:
import numpy as np
import matplotlib.pyplot as plt

In [37]:
"""Define the distance function here: d: R^n x R^n -> R,
The distance function must satisfy the following:
1. Positivity: d(x, y) >= 0 for all x, y in R^n s.t. x != y, otherwise, it is zero.
2. Symmetric: d(x, y) = d(y, x) for all x, y in R^n.
3. Triangle inequality: d(x, z) <= d(x, y) + d(y, z) for all x, y, z in R^n.
"""

def dist(x, y):
    # return np.linalg.norm(x - y) # Example: Euclidean distance
    xsq = x**2
    ysq = y**2
    sum_sq = xsq + ysq
    diff_sq = xsq - ysq
    return np.sum( np.abs(diff_sq) / sum_sq )

def notdist(x, y):
    return np.abs(np.sqrt(np.sum(x**2)) - np.sqrt(np.sum(y**2))) # Example: Not a distance

In [38]:
N = 100  # Number of points
lb, ub = -1, 1  # Lower and upper bounds for the points
v, w = np.random.uniform(lb, ub, (N, 2)), np.random.uniform(lb, ub, (N, 2))

In [39]:
# Test if the distance function is valid
def test_dist(dist, v, w):
    """
    Input:
    dist: distance function
    v: array of points
    w: array of points

    Output:
    None: if the distance function is valid
    Will return an AssertionError if the distance function is not valid
    """
    for i in range(N):
        for j in range(N):
            if i != j:
                assert dist(v[i], v[j]) >= 0, "Distance is negative"
                assert dist(v[i], v[j]) == dist(v[j], v[i]), "Distance is not symmetric"
                for k in range(N):
                    if k != i and k != j:
                        assert dist(v[i], v[k]) <= dist(v[i], v[j]) + dist(v[j], v[k]), "Triangle inequality violated"

In [40]:
test_dist(dist, v, w)

In [22]:
test_dist(notdist, v, w)

AssertionError: Triangle inequality violated