In [10]:
import jax
import jax.numpy as jnp
import numpy as np
import signatory
import torch
from signax.signature import signature, signature_to_logsignature

In [3]:
np.random.seed(0)

n_batch = 100
length = 100
dim = 10
x = np.random.rand(n_batch, n_batch, dim)

depth = 3

In [6]:
def func(x):
    def _fn(x):
        sig = signature(x, depth)
        return sum(map(jnp.sum, sig))

    def _fn_batch(batch_x):
        return jnp.sum(jax.vmap(_fn)(batch_x))

    value, grad = jax.value_and_grad(_fn_batch)(x)
    return value, grad

In [8]:
%%timeit
value, grad = func(x)
value.block_until_ready()
grad.block_until_ready()

25.5 ms ± 2.6 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [11]:
torch_x = torch.as_tensor(x).requires_grad_(True)

In [12]:
def torch_fun(x):
    sig = signatory.signature(x, depth)
    loss = torch.sum(sig)
    loss.backward()
    return loss

In [13]:
%%timeit
loss = torch_fun(torch_x)

74.5 ms ± 26.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [14]:
dim = 3

input = [
    jnp.ones((dim,)),
    jnp.ones((dim, dim)),
    jnp.ones((dim, dim, dim)),
]

output = signature_to_logsignature(input)
print(output)

[DeviceArray([1., 1., 1.], dtype=float32), DeviceArray([0.5, 0.5, 0.5], dtype=float32), DeviceArray([0.33333337, 0.33333337, 0.33333337, 0.33333337, 0.33333337,
             0.33333337, 0.33333337, 0.33333337], dtype=float32)]


In [15]:
input = torch.ones((1, dim + dim * dim + dim * dim * dim))
log_signature = signatory.signature_to_logsignature(
    signature=input, depth=3, channels=dim
)
print(log_signature)

tensor([[1.0000, 1.0000, 1.0000, 0.5000, 0.5000, 0.5000, 0.3333, 0.3333, 0.3333,
         0.3333, 0.3333, 0.3333, 0.3333, 0.3333]])
