In [1]:
import jax
jax.config.update('jax_platform_name', 'cpu')

In [2]:
import jax
import ott
import jax.numpy as jnp
from typing import Tuple, Optional

In [3]:
def sample_indices_from_tmap(key: jax.random.PRNGKeyArray, tmat: jnp.ndarray, n_samples: Optional[int]) -> Tuple[jnp.array, jnp.array]:
    n_samples = n_samples if isinstance(n_samples, int) else tmat.shape[1]
    pi_star_inds = jax.random.categorical(
                key, logits=jnp.log(tmat.flatten()), shape=(n_samples,)
            )
    return pi_star_inds // tmat.shape[1], pi_star_inds % tmat.shape[1]

In [28]:
def sample_indices_from_tmap(key: jax.random.PRNGKeyArray, tmat: jnp.ndarray, k_samples_per_x: Optional[int]) -> Tuple[jnp.array, jnp.array]:
    indices_per_row = jax.vmap(lambda tmat: jax.random.choice(key=key, a=jnp.arange(len(tmat)), p=tmat, shape=(k_samples_per_x,)), in_axes=0, out_axes=0)(tmat)
    return jnp.repeat(jnp.arange(tmat.shape[0]), k_samples_per_x), indices_per_row % tmat.shape[1]



In [35]:
from ott.geometry import pointcloud
from ott.problems.linear import linear_problem

x = jax.random.uniform(jax.random.PRNGKey(0), (20,1))

y = jax.random.uniform(jax.random.PRNGKey(1), (20,1))

ot_solver = ott.solvers.linear.sinkhorn.Sinkhorn()
geom = pointcloud.PointCloud(x, y, epsilon=1e-3)
out = ot_solver(linear_problem.LinearProblem(geom))

                


In [39]:
s,t = sample_indices_from_tmap(jax.random.PRNGKey(0), out.matrix, 4)

In [40]:
t

Array([[ 7, 13,  7, 13],
       [ 5, 14, 11, 14],
       [ 0,  3,  1,  3],
       [ 9,  9,  9,  9],
       [ 4, 16, 12, 16],
       [ 5, 11,  5,  5],
       [ 3, 10, 10, 10],
       [ 6, 15,  6,  6],
       [ 2,  8,  2,  8],
       [ 4, 16,  4, 16],
       [ 5, 14, 11, 14],
       [ 4, 18, 18, 18],
       [ 0, 15,  1, 15],
       [ 0, 15,  1, 10],
       [ 2, 17,  8,  8],
       [ 7, 13,  7, 13],
       [ 9, 19, 19, 19],
       [ 2, 17,  8, 17],
       [ 4, 16, 12, 12],
       [ 0,  3,  1,  3]], dtype=int32)

In [41]:
jnp.reshape(t, 4*len(t))


Array([ 7, 13,  7, 13,  5, 14, 11, 14,  0,  3,  1,  3,  9,  9,  9,  9,  4,
       16, 12, 16,  5, 11,  5,  5,  3, 10, 10, 10,  6, 15,  6,  6,  2,  8,
        2,  8,  4, 16,  4, 16,  5, 14, 11, 14,  4, 18, 18, 18,  0, 15,  1,
       15,  0, 15,  1, 10,  2, 17,  8,  8,  7, 13,  7, 13,  9, 19, 19, 19,
        2, 17,  8, 17,  4, 16, 12, 12,  0,  3,  1,  3], dtype=int32)

In [27]:
jnp.repeat(jnp.arange(tmat.shape[0]), 4) 

Array([ 0,  0,  0,  0,  1,  1,  1,  1,  2,  2,  2,  2,  3,  3,  3,  3,  4,
        4,  4,  4,  5,  5,  5,  5,  6,  6,  6,  6,  7,  7,  7,  7,  8,  8,
        8,  8,  9,  9,  9,  9, 10, 10, 10, 10, 11, 11, 11, 11, 12, 12, 12,
       12, 13, 13, 13, 13, 14, 14, 14, 14, 15, 15, 15, 15, 16, 16, 16, 16,
       17, 17, 17, 17, 18, 18, 18, 18, 19, 19, 19, 19], dtype=int32)