In [1]:
import jax
import jax.numpy as jnp
import numpy as np
import signatory
import torch

from signax.signature import signature, signature_to_logsignature, signature_combine

In [2]:
np.random.seed(0)

n_batch = 100
length = 100
dim = 4
x = np.random.rand(n_batch, length, dim)

depth = 3

In [3]:
def jax_grad_over_sum(_callable):
    _callable = jax.jit(jax.grad(_callable))
    def wrap(*args, **kwargs):
        def _batch_call(_args, _kwargs):
            return jnp.sum(jax.vmap(_callable)(*_args, **_kwargs))

        return jax.value_and_grad(_batch_call)(args, kwargs)

    return wrap

In [4]:
@jax_grad_over_sum
def signature_loss(path):
    sig = signature(path, depth)
    return sum(map(jnp.sum, sig))

# Compile function
_ = signature_loss(x)



In [5]:
%%timeit
value, (grad,) = signature_loss(x)
value.block_until_ready()
grad.block_until_ready()

5.55 ms ± 731 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [6]:
torch_x = torch.as_tensor(x).requires_grad_(True)

In [7]:
def torch_fun(x):
    sig = signatory.signature(x, depth)
    loss = torch.sum(sig)
    loss.backward()
    return loss

In [8]:
%%timeit
loss = torch_fun(torch_x)

4.28 ms ± 293 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [9]:
dim = 3

input_ = [
    jnp.ones((dim,)),
    jnp.ones((dim, dim)),
    jnp.ones((dim, dim, dim)),
]

output = signature_to_logsignature(input_)
print(output)

[DeviceArray([1., 1., 1.], dtype=float32), DeviceArray([0.5, 0.5, 0.5], dtype=float32), DeviceArray([0.33333337, 0.33333337, 0.33333337, 0.33333337, 0.33333337,
             0.33333337, 0.33333337, 0.33333337], dtype=float32)]


In [10]:
input_ = torch.ones((1, dim + dim * dim + dim * dim * dim))
log_signature = signatory.signature_to_logsignature(
    signature=input_, depth=3, channels=dim
)
print(log_signature)

tensor([[1.0000, 1.0000, 1.0000, 0.5000, 0.5000, 0.5000, 0.3333, 0.3333, 0.3333,
         0.3333, 0.3333, 0.3333, 0.3333, 0.3333]])


In [11]:
x1 = np.random.rand(n_batch, length, dim)
x2 = np.hstack((
    x1[:, -1:, :].copy(),
    np.random.randn(n_batch, length - 1, dim)
))


def batch_signature(path):
    return jax.vmap(lambda x_in: signature(x_in, depth))(path)


sig_x1 = batch_signature(x1)
sig_x2 = batch_signature(x2)

In [12]:
def combine(signature1, signature2):
    return jax.vmap(signature_combine)(signature1, signature2)


# Compile function
_ = combine(sig_x1, sig_x2)

In [13]:
%%timeit
combine(sig_x1, sig_x2)

305 µs ± 1.99 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [14]:
torch_x1 = torch.as_tensor(x1.copy()).requires_grad_(True)
torch_x2 = torch.as_tensor(x2.copy()).requires_grad_(True)

sig_torch_x1 = signatory.signature(torch_x1, depth)
sig_torch_x2 = signatory.signature(torch_x2, depth)

sig_torch_x1.retain_grad()
sig_torch_x2.retain_grad()

In [15]:
def torch_combine(sig1, sig2):
    return signatory.signature_combine(sig1, sig2, dim, depth)


_ = torch_combine(sig_torch_x1, sig_torch_x2)

In [16]:
%%timeit
torch_combine(sig_torch_x1, sig_torch_x2)

60.6 µs ± 196 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
