In [1]:
from functools import partial
from typing import List

import jax
import jax.numpy as jnp
import signatory
import torch

In [87]:
@partial(jax.jit, static_argnums=2)
def mult_inner(A: List[jnp.ndarray], B: List[jnp.ndarray], depth_index: int):
    """
    Cross terms of the level-wise tensor product at one output level.

    Let `depth_index` = n (0-based). This function returns

        sum_{i=0}^{n-1}  A[i] (outer product) B[n - 1 - i]

    where each term has rank `A[i].ndim + B[n - 1 - i].ndim`. In this
    notebook `A[k]` / `B[k]` hold the rank-(k + 1) tensors, so every term
    of the sum has rank n + 1.

    Args:
        A: list of arrays, left operand (one array per level).
        B: list of arrays, right operand (one array per level).
        depth_index: static Python int selecting the output level; it must
            be static because it drives a Python `range` and list indexing.

    Returns:
        The sum of the outer products. For `depth_index == 0` the sum is
        empty and evaluates to the scalar 0.

    Note this is hard to convert to `jax.lax.fori_loop`.
    I don't know if it's possible. Several attempts hit
    `TracerIntegerConversionError`, because a Python list cannot be
    indexed with a traced integer (indexing an ndarray is fine, indexing
    a list is not).
    """

    terms = []
    for i in range(depth_index):
        left = A[i]
        right = B[depth_index - i - 1]
        # Append one singleton axis to `left` for EVERY axis of `right`, so
        # that broadcasting yields a true outer product. Adding a single
        # trailing axis (`left[..., None]`) is only correct when `right` is
        # rank 1; for higher ranks the axes of `left` would be aligned with
        # the leading axes of `right` instead of being kept separate.
        expanded = left[(...,) + (None,) * right.ndim]
        terms.append(expanded * right)

    return sum(terms)

In [88]:
@jax.jit
def mult(A: List[jnp.ndarray], B: List[jnp.ndarray]):
    """
    Combine two level-wise lists of tensors.

    Output level `i` is `A[i] + B[i]` plus the cross terms returned by
    `mult_inner(A, B, depth_index=i)`.

    Args:
        A: list of arrays, one per level.
        B: list of arrays, one per level, paired element-wise with `A`.

    Returns:
        A new list with one combined array per paired level.
    """
    # Start from the element-wise sums of the paired levels.
    combined = [lhs + rhs for lhs, rhs in zip(A, B)]

    # Then add the cross terms level by level; the level index is a plain
    # Python int, so it can be passed as the static argument of mult_inner.
    for level in range(len(A)):
        combined[level] = combined[level] + mult_inner(A, B, depth_index=level)

    return combined

In [89]:
dim = 100

# One all-ones tensor per level: ranks 1, 2 and 3, every axis of length `dim`.
input1 = [jnp.ones((dim,) * rank) for rank in range(1, 4)]

In [68]:
%%timeit
# Time one evaluation of the cross terms at depth_index=2.
# block_until_ready() is called so the measurement waits for the computed
# result instead of ending as soon as the call has been dispatched.
output = mult_inner(input1, input1, depth_index=2)
output.block_until_ready()

316 µs ± 22 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [95]:
%%timeit
# Time the full level-wise combination of input1 with itself; every returned
# level is waited on so the whole computation is inside the timed region.
output = mult(input1, input1)
[o.block_until_ready() for o in output]

261 µs ± 39 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [96]:
def func(x):
    """Scalar objective: combine `x` with itself via `mult` and add up every entry of every level."""
    total = 0
    for level in mult(x, x):
        total = total + jnp.sum(level)
    return total

In [101]:
%%timeit
# Time the forward value and the gradient of `func` with respect to input1.
# NOTE(review): jax.value_and_grad(func) is rebuilt on every iteration and
# `func` itself is not wrapped in jax.jit (only `mult` is), so this timing
# presumably includes that overhead — confirm this is the intended comparison.
value, grad = jax.value_and_grad(func)(input1)
value.block_until_ready()
[g.block_until_ready() for g in grad]

7.61 ms ± 799 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [99]:
# Flat all-ones tensor with dim + dim**2 + dim**3 columns and a batch axis of 1,
# tracked for gradients.
x = torch.ones((1, sum(dim**k for k in range(1, 4))), requires_grad=True)

In [102]:
%%timeit
# PyTorch/signatory counterpart of the JAX timing: combine `x` with itself
# and backpropagate through the summed result.
# NOTE(review): backward() runs repeatedly on the same leaf `x`, so x.grad
# presumably accumulates across timing iterations — confirm this is
# irrelevant for the benchmark.
loss = signatory.signature_combine(x, x, input_channels=dim, depth=len(input1))
loss.sum().backward()

5.53 ms ± 1 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [103]:
# NOTE(review): in this view `loss` is only assigned inside a %%timeit cell;
# confirm it is defined in the notebook namespace after Restart & Run All.
loss.sum()

tensor(4030200., grad_fn=<SumBackward0>)

In [105]:
# NOTE(review): the value displayed here differs from the 4030200. shown for
# `func(input1)` further down, so this output looks stale — re-run on a
# fresh kernel before relying on it.
value

DeviceArray(5.0403e+08, dtype=float32)

In [90]:
# Combine input1 with itself and display the resulting list, one array per level.
output = mult(input1, input1)
output

[DeviceArray([2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.,
              2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.,
              2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.,
              2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.,
              2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.,
              2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.,
              2., 2., 2., 2., 2., 2., 2., 2., 2., 2.], dtype=float32),
 DeviceArray([[3., 3., 3., ..., 3., 3., 3.],
              [3., 3., 3., ..., 3., 3., 3.],
              [3., 3., 3., ..., 3., 3., 3.],
              ...,
              [3., 3., 3., ..., 3., 3., 3.],
              [3., 3., 3., ..., 3., 3., 3.],
              [3., 3., 3., ..., 3., 3., 3.]], dtype=float32),
 DeviceArray([[[4., 4., 4., ..., 4., 4., 4.],
               [4., 4., 4., ..., 4., 4., 4.],
               [4., 4., 4., ..., 4., 4., 4.],
               ...,
               [4., 4

In [94]:
# `output` holds one array per level (three here), so valid indices are 0..2;
# inspect the last (highest-level) entry.
output[-1].shape

(100, 100, 100)

In [107]:
# Reference total from signatory, for comparison with the JAX implementation.
signatory.signature_combine(x, x, input_channels=dim, depth=len(input1)).sum()

tensor(4030200., grad_fn=<SumBackward0>)

In [110]:
# JAX total for the all-ones input1; the displayed result equals the
# signatory total shown for the equivalent all-ones `x`.
value = func(input1)
value

DeviceArray(4030200., dtype=float32)