In [None]:
!pip install POT

In [None]:
import numpy as np
import numpy.random as random

import matplotlib.pyplot as plt
import ot

In [None]:
def ndot(a, b):
    """Return a[0]*b[0] - a[1]*b[1], the "negated-y" dot product of two 2-vectors.

    Only the first two entries of each argument are read.
    """
    ax, ay = a[0], a[1]
    bx, by = b[0], b[1]
    return ax * bx - ay * by

In [None]:
def clamp(value, min_val, max_val):
    """Element-wise clamp of ``value`` into ``[min_val, max_val]``.

    The upper bound is applied first and the lower bound second, so if
    ``min_val > max_val`` the result is ``min_val``.
    """
    upper_capped = np.minimum(value, max_val)
    return np.maximum(min_val, upper_capped)

In [None]:
def sdf_hexagram(p, r=0.5):
    """Signed distance from point(s) ``p`` to a hexagram (six-pointed star).

    NumPy port of Inigo Quilez's 2D ``sdHexagram``: negative inside, zero on
    the boundary, positive outside.

    Parameters
    ----------
    p : array-like, shape (2,) or (..., 2)
        Query point(s); the last axis holds (x, y).
    r : float
        Size of the star; the outer tips lie at distance ``2*r`` from the origin.

    Returns
    -------
    np.ndarray of shape ``p.shape[:-1]`` (a scalar for a single point).
    """
    k = np.array([-0.5, 0.8660254038, 0.5773502692, 1.7320508076])
    k_xy = k[:2]
    k_yx = k[1::-1]

    q = np.abs(np.asarray(p, dtype=float))
    # Fold the plane twice across lines through the origin so every point lands
    # in one wedge of the star. np.minimum keeps the reflection per point
    # (the built-in min() cannot compare whole arrays).
    q = q - 2.0 * np.expand_dims(np.minimum(q @ k_xy, 0.0), -1) * k_xy
    q = q - 2.0 * np.expand_dims(np.minimum(q @ k_yx, 0.0), -1) * k_yx

    # Offset from the closest point on the edge segment y = r,
    # x in [r*k[2], r*k[3]]; the sign of the y offset tells inside/outside.
    edge_x = np.clip(q[..., 0], r * k[2], r * k[3])
    dx = q[..., 0] - edge_x
    dy = q[..., 1] - r
    return np.hypot(dx, dy) * np.sign(dy)

In [None]:
def sdf_rhombus(p, b):
    """Signed distance from point(s) ``p`` to an origin-centred rhombus.

    NumPy port of Inigo Quilez's 2D ``sdRhombus``: negative inside, zero on
    the boundary, positive outside. Accepts a single point or a batch, so it
    can be evaluated on a whole grid at once.

    Parameters
    ----------
    p : array-like, shape (2,) or (..., 2)
        Query point(s); the last axis holds (x, y).
    b : array-like, shape (2,)
        Half-diagonals of the rhombus along x and y (vertices at (+-b[0], 0)
        and (0, +-b[1])).

    Returns
    -------
    np.ndarray of shape ``p.shape[:-1]`` (a scalar for a single point).
    """
    b = np.asarray(b, dtype=float)
    q = np.abs(np.asarray(p, dtype=float))

    # ndot(b - 2q, b) / |b|^2, evaluated per point on the last axis.
    a = b - 2.0 * q
    h = np.clip((a[..., 0] * b[0] - a[..., 1] * b[1]) / np.dot(b, b), -1.0, 1.0)

    # Closest point on the edge from (b[0], 0) to (0, b[1]).
    closest = 0.5 * b * np.stack((1.0 - h, 1.0 + h), axis=-1)
    d = np.linalg.norm(q - closest, axis=-1)
    return d * np.sign(q[..., 0] * b[1] + q[..., 1] * b[0] - b[0] * b[1])

In [None]:
def gen_grid(resolution, low=-1.0, high=1.0):
    """Return the points of a ``resolution`` x ``resolution`` grid over
    ``[low, high]^2`` as an array of shape ``(resolution**2, 2)``.

    Column 0 holds x and column 1 holds y. x varies fastest (meshgrid "xy"
    ordering), so row ``j * resolution + i`` is ``(ticks[i], ticks[j])``.
    """
    ticks = np.linspace(low, high, num=resolution)
    xs, ys = np.meshgrid(ticks, ticks)
    return np.column_stack((xs.ravel(), ys.ravel()))

In [None]:
def sample_and_normalize(f, grid, grid_size):
    '''
    Turn a signed distance function into a uniform probability image.

    f is called once on all of ``grid`` and is assumed to be > 0 outside the
    shape and <= 0 inside. Every grid point with f <= 0 gets equal mass, all
    other points get 0, and the result sums to 1.

    Parameters
    ----------
    f : callable
        Maps the grid points to one value per point.
    grid : np.ndarray
        Sample points, e.g. of shape (grid_size**2, 2) as built by gen_grid.
    grid_size : int
        Side length of the square output image.

    Returns
    -------
    np.ndarray of shape (grid_size, grid_size): non-negative, summing to 1.
    Consecutive grid points fill the image row by row.

    Raises
    ------
    ValueError
        If f returns NaN, if no grid point lies inside the shape, or if the
        number of values cannot be reshaped to (grid_size, grid_size).
    '''
    values = np.asarray(f(grid), dtype=float)
    if np.isnan(values).any():
        raise ValueError("f returned NaN; cannot build a characteristic function.")

    # Characteristic function of the closed interior {f <= 0}.
    fv = (values <= 0.0).astype(float)

    total_sum = np.sum(fv)
    if total_sum > 0:
        fv = fv / total_sum
    else:
        raise ValueError("The sum of the function values is zero; normalization is not possible.")

    return fv.reshape(grid_size, grid_size)

In [None]:
# Sample both shapes on a 64 x 64 grid covering [-2, 2]^2.
grid_size = 64
grid = gen_grid(grid_size, -2.0, 2.0)

# f1 and f2 are discrete probability images: equal mass on every grid point
# inside each shape, zero elsewhere (see sample_and_normalize).
# NOTE(review): sdf_triangle and sdf_pentagon are not defined in any visible
# cell; on a fresh kernel this raises NameError unless they are defined
# elsewhere. Confirm, or switch to a shape defined in this notebook.
f1 = sample_and_normalize(sdf_triangle, grid, grid_size)
f2 = sample_and_normalize(sdf_pentagon, grid, grid_size)

# Stack into shape (2, grid_size, grid_size) for the barycenter computation.
A = np.stack((f1, f2))

In [None]:
nb_images = 5
reg = 0.004  # regularization strength passed to convolutional_barycenter2d

v1 = np.array((1, 0))
v2 = np.array((0, 1))

# A single row of square panels: size the figure by the panel count so the
# images are not squeezed into a tall, mostly empty canvas.
fig, axes = plt.subplots(1, nb_images, figsize=(2 * nb_images, 2.5))
fig.suptitle("Convolutional Wasserstein Barycenters in POT")
cm = "Blues"

for i, ax in enumerate(axes):
    # Interpolation parameter: 0 shows f1, 1 shows f2.
    tx = float(i) / (nb_images - 1)

    if i == 0:
        image = f1
    elif i == nb_images - 1:
        image = f2
    else:
        weights = (1 - tx) * v1 + tx * v2
        image = ot.bregman.convolutional_barycenter2d(A, reg, weights)

    ax.imshow(image, cmap=cm)
    ax.set_title(f"t = {tx:.2f}")
    ax.axis("off")

plt.tight_layout()
plt.show()