## Benchmark between Signatory and Signax

This notebook is only a rough timing comparison of the two libraries, not a rigorous benchmark.

In [1]:
import jax
import jax.numpy as jnp
import numpy as np
import signatory
import torch
from signax.signature import (  # noqa: E501
    signature,
    signature_combine,
    signature_to_logsignature,
)

### Gradient computation

In [2]:
# Fix NumPy's global RNG so the random input below is reproducible.
np.random.seed(0)

# Input array of shape (n_batch, length, dim), uniform on [0, 1).
# Presumably n_batch paths, each with `length` points in `dim` channels.
n_batch = 100
length = 100
dim = 4
x = np.random.rand(n_batch, length, dim)

# Signature truncation depth, read as a global by the functions in later cells.
depth = 3

In [3]:
def func(x):
    """Return a scalar reduction of the batch signatures and its gradient.

    For each item along axis 0 of ``x`` the signature is computed with the
    module-level ``depth``; every entry of every element obtained by
    iterating over the result is summed, and those per-item totals are summed
    over the batch.

    Returns:
        A ``(value, grad)`` pair from ``jax.value_and_grad``: the scalar
        total and its gradient with respect to ``x``.
    """

    def _path_total(path):
        # Iterate over whatever `signature` returns (presumably one array per
        # signature level) and add up all entries.
        terms = signature(path, depth)
        return sum(jnp.sum(term) for term in terms)

    def _batch_total(paths):
        per_path = jax.vmap(_path_total)(paths)
        return jnp.sum(per_path)

    return jax.value_and_grad(_batch_total)(x)


# Warm-up call so that any one-time JAX tracing/compilation cost is paid
# before the timed cell below.
# NOTE(review): `func` is not wrapped in `jax.jit` here, so only whatever
# `signature` jits internally (if anything) is cached -- confirm.
_ = func(x)

In [4]:
%%timeit
# Wait on both outputs so the timing covers the actual computation rather
# than only the (asynchronous) dispatch of the work.
value, grad = func(x)
value.block_until_ready()
grad.block_until_ready()

13.7 ms ± 709 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [5]:
# Wrap the same NumPy data as a torch tensor and track gradients on it.
# NOTE(review): `x` comes from np.random.rand (float64) and as_tensor keeps
# that dtype, whereas JAX uses float32 unless x64 is enabled -- the two
# benchmarks may be running at different precisions; confirm.
torch_x = torch.as_tensor(x).requires_grad_(True)

In [6]:
def torch_fun(x):
    """Sum the Signatory signature of ``x`` and backpropagate through it.

    The truncation level is the module-level ``depth``. ``backward()`` is
    called on the scalar sum, so gradients are written to (and, on repeated
    calls, accumulated into) the ``.grad`` of the leaf tensors ``x`` depends
    on.

    Returns:
        The scalar loss tensor.
    """
    total = signatory.signature(x, depth).sum()
    total.backward()
    return total

In [7]:
%%timeit
# Times one forward and one backward pass: torch_fun calls loss.backward().
loss = torch_fun(torch_x)

10 ms ± 651 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [8]:
# NOTE: this rebinds the global `dim` (4 in the gradient benchmark above) to 3.
# Every later cell that reads `dim` sees 3.
dim = 3

# All-ones input passed as a list of three arrays with shapes (dim,),
# (dim, dim) and (dim, dim, dim) -- presumably signature levels 1 to 3.
input_ = [
    jnp.ones((dim,)),
    jnp.ones((dim, dim)),
    jnp.ones((dim, dim, dim)),
]

output = signature_to_logsignature(input_)
print(output)

[DeviceArray([1., 1., 1.], dtype=float32), DeviceArray([0.5, 0.5, 0.5], dtype=float32), DeviceArray([0.33333337, 0.33333337, 0.33333337, 0.33333337, 0.33333337,
             0.33333337, 0.33333337, 0.33333337], dtype=float32)]


In [9]:
# The same all-ones input for Signatory, as a single flat row (batch of 1)
# holding dim + dim**2 + dim**3 entries, i.e. the three levels concatenated.
# Note that this rebinds `input_` from the previous cell.
input_ = torch.ones((1, dim + dim * dim + dim * dim * dim))
log_signature = signatory.signature_to_logsignature(
    signature=input_, depth=3, channels=dim
)
print(log_signature)

tensor([[1.0000, 1.0000, 1.0000, 0.5000, 0.5000, 0.5000, 0.3333, 0.3333, 0.3333,
         0.3333, 0.3333, 0.3333, 0.3333, 0.3333]])


In [10]:
# First batch: shape (n_batch, length, dim), uniform on [0, 1).
# `dim` is 3 here, after the rebinding in the log-signature cell.
x1 = np.random.rand(n_batch, length, dim)
# Second batch: starts at the last point of the matching item in x1, so the
# two can be joined end to end, followed by length - 1 further points.
# np.hstack joins these 3-D arrays along axis 1.
# NOTE(review): the tail is drawn with randn (standard normal) while x1 uses
# rand (uniform) -- confirm this difference is intended.
x2 = np.hstack(
    (
        x1[:, -1:, :].copy(),
        np.random.randn(n_batch, length - 1, dim),
    )
)


def batch_signature(path):
    """Apply ``signature`` at the global ``depth`` to each item along axis 0 of ``path``."""

    def _single(x_in):
        return signature(x_in, depth)

    return jax.vmap(_single)(path)


# Precompute both batches of signatures outside the timed cell, so the
# benchmark below measures only the combine step.
sig_x1 = batch_signature(x1)
sig_x2 = batch_signature(x2)

In [11]:
def combine(signatures):
    """Map ``signature_combine`` over the leading axis of ``signatures``.

    ``signatures`` is handed to the vmapped function as one positional
    argument (the caller in this notebook passes a tuple of two signature
    batches).
    """
    batched_combine = jax.vmap(signature_combine)
    return batched_combine(signatures)


# Warm-up call so that any one-time JAX tracing/compilation cost is paid
# before the timed cell below.
# NOTE(review): `combine` is not wrapped in `jax.jit` here, so only whatever
# `signature_combine` jits internally (if anything) is cached -- confirm.
_ = combine((sig_x1, sig_x2))

In [12]:
%%timeit
# Block on every output array. JAX dispatches work asynchronously, so without
# this the cell may time only the dispatch rather than the computation; the
# gradient benchmark above already blocks on its outputs.
jax.tree_util.tree_map(
    lambda leaf: leaf.block_until_ready(),
    combine((sig_x1, sig_x2)),
)

323 µs ± 23.8 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [13]:
# Copy before wrapping so the torch tensors do not share memory with the
# NumPy arrays used on the JAX side.
torch_x1 = torch.as_tensor(x1.copy()).requires_grad_(True)
torch_x2 = torch.as_tensor(x2.copy()).requires_grad_(True)

# Precompute the signatures outside the timed cell. Because the inputs
# require grad, these results are non-leaf tensors attached to the autograd graph.
sig_torch_x1 = signatory.signature(torch_x1, depth)
sig_torch_x2 = signatory.signature(torch_x2, depth)

# Non-leaf tensors do not keep `.grad` by default; ask autograd to retain it.
sig_torch_x1.retain_grad()
sig_torch_x2.retain_grad()

In [14]:
def torch_combine(sig1, sig2):
    """Combine two Signatory signatures via ``signatory.signature_combine``.

    The module-level ``dim`` and ``depth`` are passed as the third and
    fourth positional arguments.
    """
    combined = signatory.signature_combine(sig1, sig2, dim, depth)
    return combined


# Untimed warm-up call, mirroring the JAX side, before the timed cell below.
_ = torch_combine(sig_torch_x1, sig_torch_x2)

In [15]:
%%timeit
# Times a single combine of the two precomputed signature batches (no backward pass).
torch_combine(sig_torch_x1, sig_torch_x2)

62.4 µs ± 2.76 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
