## "For" loops

In [1]:
import numpy as np
import numba

@numba.jit(nopython=True)
def numba_loops(arr):
    n = arr.shape[0]
    result = 0.0
    for i in range(n):
        result += arr[i] ** 2
    return result

In [118]:
%%timeit
arr = np.random.rand(1000000)
numba_loops(arr)  # Very fast

11.2 ms ± 995 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [119]:
import jax.numpy as jnp
jax.config.update("jax_enable_x64", True)
from jax import jit

@jit
def jax_loops(arr):
    n = arr.shape[0]
    result = 0.0
    for i in range(n):  # JAX does not optimize this well
        result += arr[i] ** 2
    return result

In [4]:
%%time
arr = jnp.array(np.random.rand(100))
print(jax_loops(arr))  # Not as fast as Numba

37.300316
CPU times: total: 93.8 ms
Wall time: 184 ms


In [5]:
def python_loops(arr):
    n = arr.shape[0]
    result = 0.0
    for i in range(n):  # JAX does not optimize this well
        result += arr[i] ** 2
    return result

In [6]:
%%time
arr = jnp.array(np.random.rand(10**4))
print(python_loops(arr))  # Not as fast as Numba

3342.7043
CPU times: total: 1.12 s
Wall time: 1.44 s


Very bad results when using Python for and JAX

## ODE solver

#### JAX approach

In [7]:
import jax.numpy as jnp
import jax.lax as lax
from jax import jit

def f(x, t):  
    return -x  # Example: Simple exponential decay dx/dt = -x

def step(carry, t):
    x, h = carry
    x_new = x + h * f(x, t)  # Euler step
    return (x_new, h), x_new  # (carry, output)

In [8]:
@jit
def solve_euler(x0, h, t_array):
    carry = (x0, h)
    carry, x_values = lax.scan(step, carry, t_array)
    return x_values  # Solution for all timesteps

In [9]:
n_steps = 10**5

In [10]:
t_array = jnp.linspace(0, 10, n_steps)  # Time steps
h = t_array[1] - t_array[0]  # Step size
x0 = jnp.array(1.0)  # Initial condition

In [11]:
%%timeit

solution = solve_euler(x0, h, t_array)

157 µs ± 15.2 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


#### Naïve approach

In [12]:
def euler_naive(x0, h, t_array):
    x = np.zeros(len(t_array))
    x[0] = x0
    for i, t in enumerate(t_array[:-1]):
        x[i + 1] = x[i] + h * f(x[i], t)  # Euler update

    return np.array(x)

In [13]:
%%timeit

solution = euler_naive(x0, h, t_array)

3.1 s ± 111 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


#### Numba

In [14]:
import numpy as np
import numba

@numba.njit
def solve_euler_numba(x0, h, t_array):
    n = t_array.shape[0]
    x_values = np.empty(n, dtype=np.float64)  
    x_values[0] = x0
    
    for i in range(n - 1):
        x_values[i + 1] = x_values[i] + h * (-x_values[i])  # dx/dt = -x

    return x_values

In [15]:
t_array_np = np.linspace(0, 10, n_steps)  # Time steps
h_np = t_array_np[1] - t_array_np[0]  # Step size
x0_np = np.array(1.0)  # Initial condition

In [16]:
%%timeit
solution = solve_euler_numba(x0_np, h_np, t_array_np)

196 µs ± 4.6 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


#### Jax gives the best results for many timesteps when used wisely

## Creation of shuffle tables

In [21]:
from signature.signature.shuffle_operator import ShuffleOperator
from signature.signature.alphabet import Alphabet
from signature.signature.jax_signature.shuffle_table import get_shuffle_table

In [22]:
table_trunc = 10
dim = 2

In [23]:
%%time
table_numba = ShuffleOperator(table_trunc, Alphabet(dim)).shuffle_table.T

CPU times: total: 1.88 s
Wall time: 2.12 s


In [24]:
table_numba

array([[   0,    0,    0,    1],
       [   0,    1,    1,    1],
       [   0,    2,    2,    1],
       ...,
       [2044,    0, 2044,    1],
       [2045,    0, 2045,    1],
       [2046,    0, 2046,    1]], dtype=int64)

In [28]:
%%time
table_jax = get_shuffle_table(table_trunc=table_trunc, dim=dim).T

CPU times: total: 1.81 s
Wall time: 2.18 s


In [29]:
table_jax

array([[   0,    0,    0,    1],
       [   0,    1,    1,    1],
       [   0,    2,    2,    1],
       ...,
       [2044,    0, 2044,    1],
       [2045,    0, 2045,    1],
       [2046,    0, 2046,    1]])

In [30]:
np.allclose(table_jax, table_numba)

True

## Shuffle product (when the shuffle table is compiled)

In [31]:
%load_ext autoreload
%autoreload 2
import sys
if "../.." not in sys.path:
    sys.path.append("../..")

import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import json
from signature.signature.tensor_algebra import TensorAlgebra
from signature.signature.tensor_sequence import TensorSequence
from signature.signature.stationary_signature import stationary_signature_from_path, G
from signature.signature.expected_signature import expected_signature, expected_stationary_signature
from scipy.optimize import minimize

green = "#5b6c64"
copper = "#B56246"
plt.rcParams["figure.figsize"]        = ((1 + np.sqrt(5)) / 2 * 5, 5)
plt.rcParams["figure.autolayout"]     = True
plt.rcParams["patch.force_edgecolor"] = False
plt.rcParams["axes.grid"]             = True
plt.rcParams['axes.prop_cycle']       = matplotlib.cycler(color=[green, copper, "#322B4D", "#28BC9C", "#71074E"])

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


#### Numba

In [32]:
%%time

table_trunc = 10
ts_trunc = 5

ta = TensorAlgebra(dim=2, trunc=table_trunc)

Compiling...
Compilation finished.
CPU times: total: 1.38 s
Wall time: 2.26 s


In [33]:
rng = np.random.default_rng(seed=42)
array = rng.random(size=ta.alphabet.number_of_elements(ts_trunc))

In [34]:
ts = ta.from_array(trunc=table_trunc, array=array)

In [35]:
%%timeit
res = ta.shuop.shuffle_prod(ts, ts)

4.49 ms ± 320 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [36]:
res = ta.shuop.shuffle_prod(ts, ts)
res.array.squeeze().real

array([ 0.59900797,  0.67934525,  1.32903411, ..., 79.59349657,
       67.10270751, 80.51194773])

#### Pure numpy

In [37]:
def __get_extended_array_np(ts: TensorSequence):
    n_elements = ta.alphabet.number_of_elements(table_trunc)

    new_array = np.zeros((n_elements,) + ts.shape[1:], dtype=complex)
    new_array[:min(n_elements, ts.shape[0])] = ts.array[:min(n_elements, ts.shape[0])]
    return new_array

def shuffle_prod_np(
    ts1: TensorSequence,
    ts2: TensorSequence,
    shuffle_table
):
    index_left, index_right, index_result, count = shuffle_table

    if ts1.trunc < table_trunc:
        array_1 = __get_extended_array_np(ts1)
    else:
        array_1 = ts1.array

    if ts2.trunc < table_trunc:
        array_2 = __get_extended_array_np(ts2)
    else:
        array_2 = ts2.array

    source = count * array_1[index_left, 0, 0] * array_2[index_right, 0, 0]
    linear_result = np.zeros(index_result[-1] + 1, dtype=complex)
    for i in range(len(index_result)):
        linear_result[index_result[i]] = linear_result[index_result[i]] + source[i]
    return TensorSequence(ta.alphabet, table_trunc, linear_result)

In [38]:
%%timeit
res2 = shuffle_prod_np(ts, ts, ta.shuop.shuffle_table)

112 ms ± 9.29 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [39]:
res2 = shuffle_prod_np(ts, ts, ta.shuop.shuffle_table)
ta.print(res2 - res)




In [40]:
import jax
import jax.numpy as jnp
from jax import lax
from functools import partial

from signature.signature.jax_signature.words import number_of_words_up_to_trunc

def change_trunc(ts: jax.Array, trunc: int, dim: int = 2):
    n_elements = number_of_words_up_to_trunc(trunc, dim)
    new_array = jnp.zeros(n_elements)
    new_array = new_array.at[:min(n_elements, ts.shape[0])].set(ts[:min(n_elements, ts.shape[0])])
    return new_array

@jax.jit
def shuffle_prod_jax(
    ts1: jax.Array,
    ts2: jax.Array,
    shuffle_table: jax.Array,
):
    index_left, index_right, index_result, count = shuffle_table

    source = count * ts1[index_left] * ts2[index_right]
    linear_result = ts1 * 0
    #jnp.zeros(number_of_words_up_to_trunc(trunc)) # index_result[-1] + 1
    linear_result = linear_result.at[index_result].add(source)
    
    return linear_result

shuffle_prod_jax_vect = jax.jit(jax.vmap(shuffle_prod_jax, in_axes=(1, 1, None), out_axes=1))

In [41]:
ts_arr = jnp.array(ts.array.squeeze()).real
ts_arr

Array([0.77395605, 0.43887844, 0.85859792, ..., 0.        , 0.        ,
       0.        ], dtype=float64)

In [42]:
ts_arr_2d = jnp.vstack([ts_arr] * 10).T

In [43]:
shuffle_table = get_shuffle_table(table_trunc=table_trunc)

In [122]:
%%timeit
ts_arr_2 = change_trunc(ts=ts_arr, trunc=10)
ts_arr_2 = change_trunc(ts=ts_arr, trunc=10)
shuffle_prod_jax(ts_arr, ts_arr, shuffle_table=shuffle_table)

4.02 ms ± 239 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [123]:
from signature.signature.jax_signature.tensor_sequence_jax import TensorSequenceJAX

In [203]:
ts_arr_2d

Array([[0.77395605, 0.77395605, 0.77395605, ..., 0.77395605, 0.77395605,
        0.77395605],
       [0.43887844, 0.43887844, 0.43887844, ..., 0.43887844, 0.43887844,
        0.43887844],
       [0.85859792, 0.85859792, 0.85859792, ..., 0.85859792, 0.85859792,
        0.85859792],
       ...,
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ]], dtype=float64)

In [243]:
ts_jax = TensorSequenceJAX(array=ts_arr, trunc=table_trunc, dim=dim)

In [244]:
print(ts_jax)

0.7739560485559633*0 + 0.4388784397520523*1 + 0.8585979199113825*2 + 0.6973680290593639*11 + 0.09417734788764953*12 + 0.9756223516367559*21 + 0.761139701990353*22 + 0.7860643052769538*111 + 0.12811363267554587*112 + 0.45038593789556713*121 + 0.37079802423258124*122 + 0.9267649888486018*211 + 0.6438651200806645*212 + 0.82276161327083*221 + 0.44341419882733113*222 + 0.2272387217847769*1111 + 0.5545847870158348*1112 + 0.06381725610417532*1121 + 0.8276311719925821*1122 + 0.6316643991220648*1211 + 0.7580877400853738*1212 + 0.35452596812986836*1221 + 0.9706980243949033*1222 + 0.8931211213221977*2111 + 0.7783834970737619*2112 + 0.19463870785196757*2121 + 0.4667210037270342*2122 + 0.04380376578722878*2211 + 0.15428949206754783*2212 + 0.6830489532424546*2221 + 0.7447621559078171*2222 + 0.96750973243421*11111 + 0.32582535813815194*11112 + 0.3704597060348689*11121 + 0.4695558112758079*11122 + 0.1894713590842857*11211 + 0.12992150533547164*11212 + 0.47570492622593374*11221 + 0.2269093490508841*112

In [245]:
print(ts_jax.proj(1))

NonConcreteBooleanIndexError: Array boolean indices must be concrete; got ShapedArray(bool[2047])

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.NonConcreteBooleanIndexError

In [225]:
@jax.jit
def bar(ts: TensorSequenceJAX):
    return ts.dim

In [206]:
bar(ts_jax)

Array(2, dtype=int64, weak_type=True)

In [207]:
ts_jax @ ts_jax

Array([21.94641586, 21.94641586, 21.94641586, 21.94641586, 21.94641586,
       21.94641586, 21.94641586, 21.94641586, 21.94641586, 21.94641586],      dtype=float64)

In [208]:
ts_jax.plot(trunc=3)

ValueError: could not broadcast input array from shape (15,10) into shape (15,)

In [219]:
ts_jax.shape

(2047, 10)

In [290]:
a = A(a=jnp.ones(5x), b=False)

In [305]:
@jax.jit
def foo(a: A):
    return a.b * a.a

In [292]:
foo(a)

Array([0., 0., 0., 0., 0.], dtype=float64)

In [169]:
import jax.numpy as jnp

A = jnp.ones((1000, 300, 400)) * 2  # Example shape
B = jnp.ones((1000, 300, 400))

@jax.jit
def matmul_1(A, B):
    result = jnp.einsum('i...,i...->...', A, B)  # Summing over the first axis
    return result

# @jax.jit
def matmul_2(A, B):
    result = np.sum(A * B, axis=0)
    return result

In [170]:
matmul_2(A, B)

Array([[2000., 2000., 2000., ..., 2000., 2000., 2000.],
       [2000., 2000., 2000., ..., 2000., 2000., 2000.],
       [2000., 2000., 2000., ..., 2000., 2000., 2000.],
       ...,
       [2000., 2000., 2000., ..., 2000., 2000., 2000.],
       [2000., 2000., 2000., ..., 2000., 2000., 2000.],
       [2000., 2000., 2000., ..., 2000., 2000., 2000.]], dtype=float64)

In [171]:
%%timeit
matmul_2(A, B)

170 ms ± 3.25 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [189]:
A[:, 1:10, 1]

Array([[2., 2., 2., ..., 2., 2., 2.],
       [2., 2., 2., ..., 2., 2., 2.],
       [2., 2., 2., ..., 2., 2., 2.],
       ...,
       [2., 2., 2., ..., 2., 2., 2.],
       [2., 2., 2., ..., 2., 2., 2.],
       [2., 2., 2., ..., 2., 2., 2.]], dtype=float64)

In [222]:
import signature.signature.jax_signature.words as alpha
import signature.signature.jax_signature.tensor_algebra_jax as taj

In [223]:
taj.word_to_base_dim_number(123, 2)

Array(4, dtype=int64)

In [195]:
a = (slice(1, 10), 1)
A[(slice(None), *a)]

Array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float64)

In [191]:
ts_jax[1:]

Array([0.43887844, 0.85859792, 0.69736803, ..., 0.        , 0.        ,
       0.        ], dtype=float64)

In [178]:
A.shape

(1000, 300, 400)