In [1]:
import timeit  # noqa: F401
from functools import partial
from typing import List

import jax
import jax.numpy as jnp
import numpy as np
import signatory
import torch


# Uncomment the next line to force JAX onto the CPU backend:
# jax.config.update('jax_platform_name', 'cpu')


@partial(jax.jit, static_argnames="depth")
def restricted_exp(input: jnp.ndarray, depth: int):
    """Truncated tensor exponential of a single vector.

    Level ``k`` of the result is the ``k``-fold outer product of `input`
    with itself divided by ``k!``; each level is obtained from the previous
    one by one more outer product with `input` and a division by ``k``.

    `depth` controls a Python-level loop, so it is declared static and the
    function is re-traced for every distinct value.

    Args:
        input: shape (n, )
        depth: number of levels to generate
    Return:
        A list of `jnp.ndarray`; the k-th entry (1-based) has ``k`` axes of
        length ``n``. For ``depth <= 1`` the list holds only `input`.
    """
    levels = [input]
    for order in range(2, depth + 1):
        # Outer product of the previous level with `input`, scaled by 1/order.
        higher = levels[-1][..., None] * input[None, :] / order
        levels.append(higher)
    return levels


@jax.jit
def mult_fused_restricted_exp(z: jnp.ndarray, A: List[jnp.ndarray]):
    """
    Multiply-fused-exponentiate

    For every level ``k`` (1-based) this evaluates, innermost first,

        A[k-1] + (... (A[1] + (A[0] + 1 * z/k) (x) z/(k-1)) ...) (x) z/1

    where ``(x)`` is the outer product performed by `addcmul`.

    Args:
        z: shape (n,)
        A: a list of `jnp.array` [(n, ), (n x n), (n x n x n), ...]
    Return:
        A list with ``len(A)`` entries. The accumulator starts from the
        scalar ``1.0``, which `addcmul` (as defined in this file) expands to
        shape ``(1,)``; each entry therefore has the shape of the matching
        element of `A` with an extra leading axis of length 1.
    """
    fused = []
    for level in range(1, len(A) + 1):
        acc = 1.0
        # Fold the first `level` tensors of A, shrinking the divisor of z
        # from `level` down to 1.
        for i, term in enumerate(A[:level]):
            acc = addcmul(term, acc, z=z / (level - i))
        fused.append(acc)
    return fused


def addcmul(A, prev, z):
    """Return ``A + prev (x) z``, with ``(x)`` an outer product over a new last axis.

    Args:
        A: tensor added to the outer product (broadcast against it).
        prev: tensor (or scalar) that receives a trailing axis of length 1.
        z: vector indexed as ``z[None, :]``, i.e. it fills that trailing axis.
    Return:
        ``A`` plus the broadcast product of ``prev[..., None]`` and ``z[None, :]``.
    """
    prev_column = jnp.expand_dims(prev, axis=-1)
    z_row = z[None, :]
    return A + prev_column * z_row

In [2]:
@partial(jax.jit, static_argnames="depth")
def compute_signature(path, depth):
    """Compute the truncated signature of a single path.

    The signature is built incrementally: it starts from
    ``restricted_exp`` of the first increment and every later increment is
    folded in with ``mult_fused_restricted_exp`` inside a ``fori_loop``.

    Args:
        path: array whose axis 0 indexes the points of the path; increments
            are taken with ``jnp.diff`` along that axis. Each increment is
            passed to ``restricted_exp``, which documents an ``(n,)`` input,
            so `path` is presumably ``(length, n)`` with ``length >= 2``.
        depth: truncation depth; static under ``jit``.
    Return:
        A list of arrays with the same structure and shapes as
        ``restricted_exp(path_increments[0], depth)``.
    """
    path_increments = jnp.diff(path, axis=0)
    exp_term = restricted_exp(path_increments[0], depth=depth)

    def _body(i, val):
        ret = mult_fused_restricted_exp(path_increments[i], val)
        # The loop carry must keep exactly the shapes of `val`. Reshape each
        # result to its carry shape instead of calling a bare `squeeze()`:
        # `squeeze()` would also drop genuine axes of length 1 (a path with a
        # single channel) and break the carry-type check of `fori_loop`.
        return [new.reshape(old.shape) for new, old in zip(ret, val)]

    exp_term = jax.lax.fori_loop(
        lower=1,
        upper=path_increments.shape[0],
        body_fun=_body,
        init_val=exp_term,
    )

    return exp_term

In [3]:
def func(x):
    """Return the summed signature of a batch of paths and its gradient.

    Every path along axis 0 of `x` is mapped through ``compute_signature``
    (with the module-level ``depth``, looked up when the function runs), all
    signature entries are summed into one scalar, and that scalar is
    differentiated with respect to `x`.

    Args:
        x: batch of paths; axis 0 is the batch axis consumed by ``jax.vmap``.
    Return:
        ``(value, grad)`` as produced by ``jax.value_and_grad``.
    """

    def _path_total(path):
        levels = compute_signature(path, depth)
        return sum(jnp.sum(level) for level in levels)

    def _batch_total(batch):
        per_path = jax.vmap(_path_total)(batch)
        return jnp.sum(per_path)

    return jax.value_and_grad(_batch_total)(x)

In [4]:
# Fixed seed so the benchmark input is reproducible.
np.random.seed(0)

n_batch = 100  # number of paths in the batch (axis 0, mapped over by vmap)
length = 100  # number of points per path (axis 1, differenced by jnp.diff)
dim = 10  # number of channels per point (last axis)
# Use `length` for the path axis; it previously repeated `n_batch`, which only
# worked because both constants happen to be 100.
x = np.random.rand(n_batch, length, dim)

# Truncation depth of the signature used by both implementations.
depth = 3

In [5]:
%%timeit
# Time value + gradient of the JAX implementation. The two
# block_until_ready() calls wait for the returned arrays to be computed
# before the timer stops.
value, grad = func(x)
value.block_until_ready()
grad.block_until_ready()

14.5 ms ± 663 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [6]:
# Wrap the benchmark input as a tensor and track gradients on it so
# backward() can populate `torch_x.grad`.
# NOTE(review): `x` comes from np.random.rand (float64) and as_tensor keeps
# that dtype, while JAX keeps float64 only when `jax_enable_x64` is set, which
# this notebook does not visibly do. The two benchmarks may therefore run at
# different precision; confirm this is intended.
torch_x = torch.as_tensor(x)
torch_x.requires_grad = True

In [7]:
def torch_fun(x):
    """Sum the signatory signature of `x` and back-propagate through it.

    Uses the module-level ``depth`` as truncation depth.

    Args:
        x: tensor passed to ``signatory.signature``; it must track gradients
            for ``backward()`` to populate ``x.grad``.
    Return:
        The scalar loss (sum of all signature entries). After the call,
        ``x.grad`` holds the gradient of this call only.
    """
    # backward() accumulates into x.grad; drop any gradient left by a previous
    # call so repeated invocations (e.g. under %%timeit) do not add up.
    x.grad = None
    sig = signatory.signature(x, depth)
    loss = torch.sum(sig)
    loss.backward()
    return loss

In [8]:
%%timeit
# Time one forward signature plus backward pass of the signatory/PyTorch
# implementation on the tensor built from the same input `x`.
loss = torch_fun(torch_x)

28.3 ms ± 3.73 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
