In [1]:
import jax
# Set JAX's global `jax_platform_name` option to 'cpu' before any arrays are created.
jax.config.update('jax_platform_name', 'cpu')

In [2]:
import jax
import ott
import jax.numpy as jnp
from typing import Tuple, Optional

In [3]:
def sample_indices_from_tmap(key: jnp.ndarray, tmat: jnp.ndarray, n_samples: Optional[int]) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Sample (row, column) index pairs jointly from a 2-D transport matrix.

    The matrix is flattened and its entries are used as unnormalised
    probabilities of a single categorical distribution over all cells, so a
    draw selects a source index and a target index together.

    Note: a later cell in this notebook defines another function with the same
    name; when the cells run in order that later definition replaces this one.

    Args:
        key: JAX PRNG key used for the categorical draw.
        tmat: 2-D array of non-negative weights (zeros are allowed; they get
            log-weight ``-inf`` and are never selected while some entry is
            positive).
        n_samples: Number of pairs to draw. ``None`` falls back to the number
            of columns of ``tmat``.

    Returns:
        A tuple ``(row_indices, col_indices)``, each of shape ``(n_samples,)``.
    """
    n_cols = tmat.shape[1]
    # Only `None` triggers the default, so NumPy integers are honoured as well
    # as built-in ints.
    n_draws = n_cols if n_samples is None else n_samples
    flat_inds = jax.random.categorical(
        key, logits=jnp.log(tmat.flatten()), shape=(n_draws,)
    )
    # Undo the row-major flattening: quotient is the row, remainder the column.
    return flat_inds // n_cols, flat_inds % n_cols

In [4]:
def sample_indices_from_tmap(key: jnp.ndarray, tmat: jnp.ndarray, k_samples_per_x: Optional[int]) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Sample ``k_samples_per_x`` column indices for every row of a transport matrix.

    Each row of ``tmat`` is used as (unnormalised) probabilities over the
    columns, and the draws for that row are made with replacement.

    Args:
        key: JAX PRNG key. It is split into one subkey per row so that the
            rows are sampled independently of each other.
        tmat: 2-D array of non-negative weights, one row per source point.
        k_samples_per_x: Number of column indices to draw for each row. Must
            be a concrete integer; ``None`` is not handled.

    Returns:
        A tuple ``(row_indices, col_indices)`` where ``row_indices`` is a flat
        array of shape ``(n_rows * k_samples_per_x,)`` holding each row index
        repeated ``k_samples_per_x`` times, and ``col_indices`` has shape
        ``(n_rows, k_samples_per_x)``.
    """
    n_rows, n_cols = tmat.shape
    # Passing one shared key to every row would make all rows consume the same
    # random bits (identical rows would yield identical samples), so derive an
    # independent subkey for each row instead.
    row_keys = jax.random.split(key, n_rows)

    def _sample_row(row_key, row):
        # Draw column indices for a single row using that row's weights.
        return jax.random.choice(
            key=row_key, a=jnp.arange(n_cols), p=row, shape=(k_samples_per_x,)
        )

    indices_per_row = jax.vmap(_sample_row, in_axes=(0, 0), out_axes=0)(row_keys, tmat)
    return jnp.repeat(jnp.arange(n_rows), k_samples_per_x), indices_per_row



In [62]:
from ott.geometry import pointcloud
from ott.problems.linear import linear_problem

# Source points: 6 one-dimensional samples, sorted by value.
x = jax.random.uniform(jax.random.PRNGKey(0), (6, 1))
# The arrays have shape (6, 1), so sorting along axis=1 (length 1) would leave
# them unchanged; axis=0 orders the points themselves.
x = jnp.sort(x, axis=0)

# Target points: same construction with a different seed.
y = jax.random.uniform(jax.random.PRNGKey(1), (6, 1))
y = jnp.sort(y, axis=0)

# Solve the linear OT problem between the two point clouds with Sinkhorn;
# `out.matrix` is consumed by the sampling cell below.
ot_solver = ott.solvers.linear.sinkhorn.Sinkhorn()
geom = pointcloud.PointCloud(x, y, epsilon=1e-1)
out = ot_solver(linear_problem.LinearProblem(geom))

                


In [63]:
# Sample from the coupling `out.matrix`, passing 4 as the sample-count argument;
# `s` receives the first returned index array and `t` the second.
s,t = sample_indices_from_tmap(jax.random.PRNGKey(0), out.matrix, 4)

In [64]:
# Uniform random values shaped like `t` with one extra trailing axis of size 1.
random_batches = jax.random.uniform(
    jax.random.PRNGKey(0),
    shape=t.shape + (1,),
)

In [65]:
# Shape check: `t.shape` followed by a trailing axis of size 1.
random_batches.shape

(6, 4, 1)

In [77]:
# Gather target points with the sampled indices; the result should have the
# same shape as `random_batches` so the two can be vmapped together.
y[t].shape

(6, 4, 1)

In [78]:
# One PRNG key per entry along the leading axis of `random_batches`, to be
# mapped over alongside the batches.
a = jax.random.split(jax.random.PRNGKey(0), len(random_batches))
a.shape

(6, 2)

In [68]:
def sinkhorn(x, y, key):
    """Couple two point sets with Sinkhorn and return sampled matched points.

    Builds a point-cloud geometry between ``x`` and ``y`` (``epsilon=1e-2``,
    cost scaled by its mean), solves the resulting linear OT problem, and
    samples indices from the transport matrix via
    ``sample_indices_from_tmap`` with a sample count of 1.

    Args:
        x: Source points; indexed along its first axis by the sampled source
            indices.
        y: Target points; indexed along its first axis by the sampled target
            indices.
        key: JAX PRNG key forwarded to ``sample_indices_from_tmap``.

    Returns:
        A tuple ``(x[source_indices], y[target_indices])``. The two parts can
        differ in rank when the sampler returns index arrays of different
        shapes.
    """
    problem = linear_problem.LinearProblem(
        pointcloud.PointCloud(x, y, epsilon=1e-2, scale_cost="mean")
    )
    solution = ott.solvers.linear.sinkhorn.Sinkhorn()(problem)
    src_idx, tgt_idx = sample_indices_from_tmap(key, solution.matrix, 1)
    return x[src_idx], y[tgt_idx]
    

In [69]:
# Run `sinkhorn` independently on each leading-axis slice of the three inputs.
batched_sinkhorn = jax.vmap(sinkhorn, in_axes=0, out_axes=0)
r1, r2 = batched_sinkhorn(random_batches, y[t], a)

In [70]:
# Compare the shapes of the two batched outputs.
r1.shape, r2.shape

((6, 4, 1), (6, 4, 1, 1))

In [75]:
# View all values of `r1` as a single 1-D array.
r1.flatten()

Array([0.6433916 , 0.18188512, 0.02240455, 0.563781  , 0.5526401 ,
       0.0958724 , 0.34253013, 0.03644359, 0.08744538, 0.7909105 ,
       0.35205448, 0.53364205, 0.02900076, 0.4168595 , 0.5802449 ,
       0.91486526, 0.27414513, 0.14991808, 0.9383501 , 0.5209162 ,
       0.51207185, 0.90618336, 0.7309413 , 0.95533276], dtype=float32)

In [76]:
# Display `r1` with its batch structure intact.
r1

Array([[[0.6433916 ],
        [0.18188512],
        [0.02240455],
        [0.563781  ]],

       [[0.5526401 ],
        [0.0958724 ],
        [0.34253013],
        [0.03644359]],

       [[0.08744538],
        [0.7909105 ],
        [0.35205448],
        [0.53364205]],

       [[0.02900076],
        [0.4168595 ],
        [0.5802449 ],
        [0.91486526]],

       [[0.27414513],
        [0.14991808],
        [0.9383501 ],
        [0.5209162 ]],

       [[0.51207185],
        [0.90618336],
        [0.7309413 ],
        [0.95533276]]], dtype=float32)