In [89]:
import jax
import jax.numpy as jnp

from ott.geometry import pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn

# sample two point clouds and their weights.
rngs = jax.random.split(jax.random.PRNGKey(0), 4)
n, m, d = 20, 10, 2
x = jax.random.normal(rngs[0], (n,d)) + 1
y = jax.random.uniform(rngs[1], (m,d))
a = jnp.ones(n) / n
b = jnp.ones(m) / m
# b = jax.random.uniform(rngs[3], (m,))
# a, b = a / jnp.sum(a), b / jnp.sum(b)
# Computes the couplings using the Sinkhorn algorithm.
geom = pointcloud.PointCloud(x, y)
prob = linear_problem.LinearProblem(geom, a, b)

solver = sinkhorn.Sinkhorn()
out = solver(prob)

dual_potentials = out.to_dual_potentials()



In [91]:

print("values used for f_evaludation: ", dual_potentials.f.keywords)
g_potential = dual_potentials.g
print(type(g_potential))
print("g_potential: ", g_potential(y))

values used for f_evaludation:  {'potential': Array([ 0.43771   , -0.38668346, -0.42691904, -0.25529772, -0.0946463 ,
        0.12588829,  0.02136417, -0.41771907, -0.28404194, -0.02323455],      dtype=float32), 'y': Array([[0.03725219, 0.0924269 ],
       [0.04939151, 0.8678386 ],
       [0.93742704, 0.71651375],
       [0.89610386, 0.64479685],
       [0.5111505 , 0.85136247],
       [0.34265172, 0.6067195 ],
       [0.58651876, 0.5767032 ],
       [0.9267553 , 0.01023901],
       [0.9460478 , 0.548363  ],
       [0.5799898 , 0.75638044]], dtype=float32), 'weights': Array([0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], dtype=float32), 'epsilon': Array(0.08403818, dtype=float32)}
<class 'jax.tree_util.Partial'>
g_potential:  [0.8830315  0.05849338 0.01833196 0.18995856 0.35065666 0.57119215
 0.4666631  0.0274261  0.1612047  0.42207268]


In [28]:
import ot
import numpy as np 
from gurobipy import GRB
import gurobipy as gp

def W2_pot(X, Y):
    M = ot.dist(X, Y)
    print(type(M))
    a, b = jnp.ones((X.shape[0],)) / X.shape[0], jnp.ones((Y.shape[0],)) / Y.shape[0]
    W2_sq = ot.emd2(a, b, M)
    # pi_star = ot.emd(a, b, M)
    return W2_sq

def W2(X, Y):
    m, n = X.shape[0], Y.shape[0]
    model = gp.Model("LP_OptCoupling")
    model.setParam('OutputFlag', 0)
    pi = {}
    for i in range(m):
        for j in range(n):
            pi[i, j] = model.addVar(lb=0.0, ub = 1.0, vtype=GRB.CONTINUOUS, name=f"pi_{i}_{j}")
    model.update()
    
    obj = gp.quicksum(pi[i, j] * np.linalg.norm(X[i] - Y[j])**2 for i in range(m) for j in range(n))
    model.setObjective(obj, GRB.MINIMIZE)

    for j in range(n):
        model.addConstr(gp.quicksum(pi[i, j] for i in range(m)) == 1/n)
    for i in range(m):
        model.addConstr(gp.quicksum(pi[i, j] for j in range(n)) == 1/m)
    model.optimize()
    # pi_star = np.array([pi[i, j].x for i in range(m) for j in range(n)]).reshape((m, n))  
    W2_sq = model.objVal
    return W2_sq


In [29]:
# sample two point clouds and their weights.
rngs = jax.random.split(jax.random.PRNGKey(0), 4)
n, m, d = 100, 10, 2
x = jax.random.normal(rngs[0], (n,d)) + 1
y = jax.random.uniform(rngs[1], (m,d))
a = jnp.ones(n) / n
b = jnp.ones(m) / m

In [32]:
print(W2_pot(x, y))

<class 'jaxlib.xla_extension.ArrayImpl'>
1.1454308


In [53]:
import numpy as np
import matplotlib.pylab as pl
import ot

import jax
import jax.numpy as jnp

from ott.geometry import pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn

geom = pointcloud.PointCloud(x, y, epsilon = 0.1)

print(f"epsilon: {geom.epsilon}")
prob = linear_problem.LinearProblem(geom) # uniform weights

solver = sinkhorn.Sinkhorn()
out = solver(prob) # EntropicPotential object
dual_potentials = out.to_dual_potentials()

g_potential = dual_potentials.f.keywords['potential']
# the potential g corresponds to the output of EntropicPotntial.f
# c.f. https://ott-jax.readthedocs.io/en/latest/_modules/ott/problems/linear/potentials.html#EntropicPotentials
# Y = dual_potentials.f.keywords['y']
epsilon = dual_potentials.f.keywords['epsilon']

print(f"epsilon: {epsilon}")


epsilon: 0.1
epsilon: 0.1


In [40]:
from scipy.linalg import norm
sample =np.array([0.5, 0.5])
x_tile = np.tile(sample, (n, 1))
print(norm(x_tile - y, axis=1)** 2)

[0.14389363 0.20710975 0.23457538 0.02798691 0.40696356 0.34866166
 0.373453   0.2456308  0.2314102  0.21700819 0.2958     0.20237379
 0.25505906 0.23461138 0.09088869 0.10233277 0.2590178  0.17388576
 0.11190224 0.38737187 0.2846372  0.08286312 0.13343224 0.22849487
 0.23577824 0.25796857 0.02982283 0.2733367  0.2811058  0.04070041
 0.23541781 0.23866093 0.1776222  0.01921717 0.13541923 0.01650825
 0.212506   0.14445087 0.21856578 0.2131603  0.35207427 0.14142059
 0.11686076 0.38433674 0.33401284 0.41602305 0.19954206 0.01084337
 0.00704422 0.16643041 0.25924224 0.10167656 0.21490668 0.02486121
 0.29160318 0.1591314  0.22006577 0.08258775 0.25851774 0.15428253
 0.22448091 0.14961982 0.1665787  0.13515718 0.01356162 0.2757968
 0.22200416 0.0625207  0.42096552 0.21346717 0.2914092  0.24231994
 0.20014521 0.13139543 0.11172502 0.06622326 0.09119834 0.18089163
 0.07927293 0.22868791 0.23226231 0.20834972 0.03030646 0.15636243
 0.31960633 0.10338381 0.21210496 0.22246556 0.24497157 0.21632

In [41]:
sample = np.array([0.5, 0.5])
x_tile = np.tile(sample, (n, 1))
print(x_tile.shape)
exponent_vec = (g_potential - norm(x_tile - y, axis = 1)**2) / epsilon
numerator = y.T @ np.exp(exponent_vec)
denominator = np.sum(np.exp(exponent_vec))
entropic_image = numerator / denominator


(1000, 2)


In [42]:
print(entropic_image)

[0.28774828 0.31717014]


In [45]:
M = 0.01
half_xsq = norm(sample)**2 / 2
regularized_entropic_image = entropic_image + np.exp(1/(half_xsq - M)) * sample
print(regularized_entropic_image)

[32.537792 32.567215]


In [69]:
class MixtureOfGaussians:
### For generating samples from a mixture of Gaussian distributions (underlying barycenter measure) ###

    def __init__(self, dim, weights=None):
        self.truncation = False
        # Default weights if not provided (equally distributed)
        if weights is None:
            self.weights = []
        else:
            self.weights = weights
            self.weights /= np.sum(self.weights)
        
        # Initialize list to record parameters for each Gaussian component
        self.gaussians = []
        self.dim = dim

    def add_gaussian(self, mean, cov):
        self.gaussians.append((mean, cov))
        
    def set_weights(self, weights):
        self.weights = weights
        self.weights /= np.sum(self.weights)

    def set_truncation(self, radius):
        self.truncation = True
        self.radius = radius

    def random_components(self, num_components, seed = 42):
        dim = self.dim
        rng_component = np.random.RandomState(seed)
        for _ in range(num_components):
            mean = (rng_component.rand(dim) - 0.5) * 100
            A = rng_component.rand(dim, dim) - 0.5
            cov = (np.dot(A, A.T) + np.eye(dim)) * 100
            self.add_gaussian(mean, cov)
        weights = rng_component.rand(num_components)
        self.set_weights(weights)

    def sample(self, n, seed = None):
        dim = self.dim
        count = 0
        samples = np.zeros((n, dim))
        rng_sample = np.random.RandomState(seed)
        while count < n:
            choice = rng_sample.choice(len(self.gaussians), p=self.weights)
            mean, cov = self.gaussians[choice]
            sample = rng_sample.multivariate_normal(mean, cov)
            if not self.truncation or np.linalg.norm(sample) <= self.radius:
                samples[count] = sample
                count += 1
        return samples


In [82]:
rng1 = np.random.RandomState(500)
source_sampler = MixtureOfGaussians(2)
source_sampler.random_components(5, seed = 42)
source_sampler.set_truncation(100)
source_samples = source_sampler.sample(2)
print(source_sampler.gaussians)


test_list = []
for i in range(10):
    test_list.append(rng1.rand())
print(test_list)

[(array([-12.54598812,  45.07143064]), array([[106.35546855, -11.37406507],
       [-11.37406507, 123.66629458]])), (array([-44.19163878,  36.61761458]), array([[105.35184432,   4.92992498],
       [  4.92992498, 145.06544963]])), (array([ 33.24426408, -28.76608893]), array([[120.14680559,   5.44474562],
       [  5.44474562, 103.89339803]])), (array([ -6.80549814, -20.87708598]), array([[114.24757466,   2.49281532],
       [  2.49281532, 106.10630041]])), (array([-4.39300158, 28.51759614]), array([[109.03984564,  -3.42105416],
       [ -3.42105416, 121.42476807]]))]
[0.10314551375085124, 0.7398659549256978, 0.7522912198615602, 0.3435628932951692, 0.9899250252469857, 0.5912641309775509, 0.20872419490123362, 0.782172384869658, 0.8155081806265134, 0.689432533728304]


In [95]:
import math
math.exp(1000)

OverflowError: math range error

In [99]:
x = np.array([1, 2, 3, 4])
x_max = np.max(x)
print(x_max)
x_normalized = x - x_max
print(x_normalized)

4
[-3 -2 -1  0]
