In [6]:
import jax
import jax.numpy as jnp
from jax import random, pmap

In [None]:

@jax.tree_util.register_pytree_node_class
class SPDAffineInvariantMetric:
    """Affine-invariant metric on symmetric positive-definite (SPD) matrices (skeleton).

    None of the geometric operations is implemented yet; each one raises
    ``NotImplementedError`` rather than silently returning ``None``.

    Instances are registered as JAX pytrees with no array leaves: ``k`` is
    carried as static auxiliary data, so it stays a plain Python value inside
    ``jax.jit``/``jax.vmap`` and therefore must be hashable.

    Args:
        k: Stored as ``self.k``. NOTE(review): presumably the SPD matrix size
            (points are k x k matrices) — confirm before implementing.
    """

    def __init__(self, k):
        self.k = k

    # --- pytree protocol (required by register_pytree_node_class) ---

    def tree_flatten(self):
        """Return ``(children, aux_data)``: no array leaves, ``k`` as static data."""
        return (), (self.k,)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Rebuild an instance from the output of ``tree_flatten``."""
        del children  # there are no leaves to restore
        (k,) = aux_data
        return cls(k)

    # --- Riemannian operations: placeholders, not implemented yet ---

    def _exp(self, tangent_vec, base_point):
        """Placeholder for the exponential map of ``tangent_vec`` at ``base_point``."""
        raise NotImplementedError("SPDAffineInvariantMetric._exp is not implemented")

    def _log(self, point, base_point):
        """Placeholder for the logarithm map of ``point`` at ``base_point``."""
        raise NotImplementedError("SPDAffineInvariantMetric._log is not implemented")

    def _retraction(self, tangent_vec, base_point):
        """Placeholder for a retraction of ``tangent_vec`` at ``base_point``."""
        raise NotImplementedError("SPDAffineInvariantMetric._retraction is not implemented")

    def _parallel_transport(self, tangent_vec, start_point, end_point):
        """Placeholder for parallel transport of ``tangent_vec`` from ``start_point`` to ``end_point``."""
        raise NotImplementedError("SPDAffineInvariantMetric._parallel_transport is not implemented")

    def _vector_transport(self, tangent_vec, start_point, end_point):
        """Placeholder for vector transport of ``tangent_vec`` from ``start_point`` to ``end_point``."""
        raise NotImplementedError("SPDAffineInvariantMetric._vector_transport is not implemented")

In [14]:
from functools import partial
from timeit import timeit
from jax import vmap, jit, random, numpy as jnp

# Benchmark inputs: a batch `a` of shape (n, d) and a matrix `b` of shape (d, d).
# Independent subkeys keep `a` and `b` distinct; a single shared key with n == d
# would make them the same matrix.
n, d = 2048, 2048
NUM_CALLS = 1000  # timed calls per variant
key_a, key_b = random.split(random.PRNGKey(0))
a = random.normal(key_a, (n, d))
b = random.normal(key_b, (d, d))

mm = jnp.matmul
# vmap over the rows of `a` (axis 0) while passing `b` unbatched to every call.
v = partial(vmap, in_axes=(0, None))

# Prints one line per variant, in this order: mm, v(mm), jit(mm), v(jit(mm)), jit(v(mm)).
# Each value is the total seconds for NUM_CALLS calls.
for f in [mm, v(mm), jit(mm), v(jit(mm)), jit(v(mm))]:
    # Bind `f` explicitly; block_until_ready() makes each call wait for the result.
    run = lambda f=f: f(a, b).block_until_ready()
    # setup=run performs one untimed call before the timed loop (warm-up).
    t = timeit(run, setup=run, number=NUM_CALLS)
    print(f'{t:.3f}')

25.786
32.332
27.258
29.000
34.318


In [15]:
from jax import numpy as np, vmap, jit, random
import itertools

key = random.PRNGKey(0)
# Split the key so the dataset and the query permutation draw independent
# randomness (a JAX key should not be reused across random calls).
key_data, key_queries = random.split(key)

N_ROWS, N_COLS = 60000, 100  # dataset shape
QUERY_ORDER = 3              # number of columns multiplied together per query
N_QUERIES = 1000             # number of queries kept after shuffling

# dataset: boolean matrix, each entry True with probability 0.5
D = random.bernoulli(key_data, 0.5, shape=[N_ROWS, N_COLS])
# queries are the column indices we want to multiply: a random sample of
# N_QUERIES distinct QUERY_ORDER-column combinations, one per row.
# Combinations are built from a Python range (plain ints) rather than
# np.arange, which would yield JAX scalar arrays that are far slower to convert.
all_combinations = np.array(list(itertools.combinations(range(N_COLS), QUERY_ORDER)))
queries = random.permutation(key_queries, all_combinations)[:N_QUERIES]

def _single_query(D, query):
    """Evaluate one query: the row-average of the product of D's entries in the query columns.

    For a boolean ``D`` this is the fraction of rows in which every column
    listed in ``query`` is True.

    Args:
        D: 2-D dataset, rows indexed on axis 0 and columns on axis 1.
        query: column indices to multiply together (presumably 1-D — confirm
            with callers; this is how it is used under vmap over query rows).

    Returns:
        Scalar: sum over rows of the per-row product, divided by the row count.
    """
    selected_cols = D[:, query]
    row_products = np.prod(selected_cols, axis=1)
    n_rows = D.shape[0]
    return np.sum(row_products) / n_rows

def auto_batched_preserve_subset_statistic(queries):
    """Build a jitted function that evaluates ``_single_query`` for every query row via vmap.

    Args:
        queries: array of column-index queries, one per row (axis 0); it is
            captured by the returned closure.

    Returns:
        ``compute_statistic(D)``, returning a 1-D array with one statistic per
        query row, in query order.
    """
    # Batch over the query rows (axis 0) while passing the same dataset D to every call.
    batched_query = vmap(_single_query, in_axes=(None, 0))

    @jit
    def compute_statistic(D):
        # No inner jit needed: the @jit above already compiles this whole body.
        return batched_query(D, queries)

    return compute_statistic

def hand_batched_preserve_subset_statistic(queries, num_chunks=10):
    """Build a jitted function computing the subset statistic for every query, batched by hand.

    For each query row, the statistic is the sum over rows of ``D`` of the
    product of D's entries in that query's columns, divided by the number of
    rows. Instead of vmap, this version uses explicit batched indexing over
    groups of queries.

    Args:
        queries: 2-D array of column indices, one query per row; captured by
            the returned closure.
        num_chunks: number of groups the queries are split into (via
            ``array_split``); each group is gathered and reduced separately.
            Defaults to 10, the previously hard-coded value.

    Returns:
        ``compute_statistic(D)``, returning a 1-D array with one statistic per
        query row, in query order.
    """
    # The split depends only on `queries`, so do it once, outside the traced function.
    query_chunks = np.array_split(queries, num_chunks)

    @jit
    def compute_statistic(D):
        # D[:, chunk] has shape (rows, queries_in_chunk, cols_per_query): multiply
        # across each query's columns (axis 2), then sum over the rows (axis 0).
        return np.concatenate([
            np.prod(D[:, chunk], axis=2).sum(axis=0) for chunk in query_chunks
        ]) / D.shape[0]

    return compute_statistic

# Build both statistic functions over the same fixed query set so the timings
# below compare like with like.
# hand batched statistic function
hb_compute_statistic = hand_batched_preserve_subset_statistic(queries)
# vmap/auto batched statistic function
ab_compute_statistic = auto_batched_preserve_subset_statistic(queries)

# Sanity check before benchmarking: both implementations must give the same
# result. This also makes one untimed first call to each jitted function.
assert np.allclose(ab_compute_statistic(D), hb_compute_statistic(D)), \
    "auto-batched and hand-batched statistics disagree"


In [21]:
%%timeit 
ab_compute_statistic(D).block_until_ready()

25.9 ms ± 2.48 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [22]:
%%timeit 

hb_compute_statistic(D).block_until_ready()

35.3 ms ± 203 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
