## "For" loops

In [117]:
import numpy as np
import numba

@numba.jit(nopython=True)
def numba_loops(arr):
    n = arr.shape[0]
    result = 0.0
    for i in range(n):
        result += arr[i] ** 2
    return result

In [118]:
%%timeit
arr = np.random.rand(1000000)
numba_loops(arr)  # Very fast

11.2 ms ± 995 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [119]:
import jax.numpy as jnp
jax.config.update("jax_enable_x64", True)
from jax import jit

@jit
def jax_loops(arr):
    n = arr.shape[0]
    result = 0.0
    for i in range(n):  # JAX does not optimize this well
        result += arr[i] ** 2
    return result

In [4]:
%%time
arr = jnp.array(np.random.rand(100))
print(jax_loops(arr))  # Not as fast as Numba

37.300316
CPU times: total: 93.8 ms
Wall time: 184 ms


In [5]:
def python_loops(arr):
    n = arr.shape[0]
    result = 0.0
    for i in range(n):  # JAX does not optimize this well
        result += arr[i] ** 2
    return result

In [6]:
%%time
arr = jnp.array(np.random.rand(10**4))
print(python_loops(arr))  # Not as fast as Numba

3342.7043
CPU times: total: 1.12 s
Wall time: 1.44 s


Very bad results when using Python for and JAX

## ODE solver

#### JAX approach

In [7]:
import jax.numpy as jnp
import jax.lax as lax
from jax import jit

def f(x, t):  
    return -x  # Example: Simple exponential decay dx/dt = -x

def step(carry, t):
    x, h = carry
    x_new = x + h * f(x, t)  # Euler step
    return (x_new, h), x_new  # (carry, output)

In [8]:
@jit
def solve_euler(x0, h, t_array):
    carry = (x0, h)
    carry, x_values = lax.scan(step, carry, t_array)
    return x_values  # Solution for all timesteps

In [9]:
n_steps = 10**5

In [10]:
t_array = jnp.linspace(0, 10, n_steps)  # Time steps
h = t_array[1] - t_array[0]  # Step size
x0 = jnp.array(1.0)  # Initial condition

In [11]:
%%timeit

solution = solve_euler(x0, h, t_array)

157 µs ± 15.2 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


#### Naïve approach

In [12]:
def euler_naive(x0, h, t_array):
    x = np.zeros(len(t_array))
    x[0] = x0
    for i, t in enumerate(t_array[:-1]):
        x[i + 1] = x[i] + h * f(x[i], t)  # Euler update

    return np.array(x)

In [13]:
%%timeit

solution = euler_naive(x0, h, t_array)

3.1 s ± 111 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


#### Numba

In [14]:
import numpy as np
import numba

@numba.njit
def solve_euler_numba(x0, h, t_array):
    n = t_array.shape[0]
    x_values = np.empty(n, dtype=np.float64)  
    x_values[0] = x0
    
    for i in range(n - 1):
        x_values[i + 1] = x_values[i] + h * (-x_values[i])  # dx/dt = -x

    return x_values

In [15]:
t_array_np = np.linspace(0, 10, n_steps)  # Time steps
h_np = t_array_np[1] - t_array_np[0]  # Step size
x0_np = np.array(1.0)  # Initial condition

In [16]:
%%timeit
solution = solve_euler_numba(x0_np, h_np, t_array_np)

196 µs ± 4.6 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


#### Jax gives the best results for many timesteps when used wisely

## Creation of shuffle tables

In [None]:
table_trunc = 10
dim = 2

In [104]:
%%time
table_numba = ShuffleOperator(table_trunc, Alphabet(dim)).shuffle_table.T

CPU times: total: 2.3 s
Wall time: 2.93 s


In [105]:
table_numba

array([[   0,    0,    0,    1],
       [   0,    1,    1,    1],
       [   0,    2,    2,    1],
       ...,
       [2044,    0, 2044,    1],
       [2045,    0, 2045,    1],
       [2046,    0, 2046,    1]], dtype=int64)

In [111]:
%%time
table_jax = get_shuffle_table(table_trunc=table_trunc, dim=dim).T

CPU times: total: 2.08 s
Wall time: 2.67 s


In [113]:
table_jax

array([[   0,    0,    0,    1],
       [   0,    1,    1,    1],
       [   0,    2,    2,    1],
       ...,
       [2044,    0, 2044,    1],
       [2045,    0, 2045,    1],
       [2046,    0, 2046,    1]])

In [114]:
np.allclose(table_jax, table_numba)

True

## Shuffle product (when the shuffle table is compiled)

In [120]:
%load_ext autoreload
%autoreload 2
import sys
if "../.." not in sys.path:
    sys.path.append("../..")

import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import json
from signature.signature.tensor_algebra import TensorAlgebra
from signature.signature.tensor_sequence import TensorSequence
from signature.signature.stationary_signature import stationary_signature_from_path, G
from signature.signature.expected_signature import expected_signature, expected_stationary_signature
from scipy.optimize import minimize

green = "#5b6c64"
copper = "#B56246"
plt.rcParams["figure.figsize"]        = ((1 + np.sqrt(5)) / 2 * 5, 5)
plt.rcParams["figure.autolayout"]     = True
plt.rcParams["patch.force_edgecolor"] = False
plt.rcParams["axes.grid"]             = True
plt.rcParams['axes.prop_cycle']       = matplotlib.cycler(color=[green, copper, "#322B4D", "#28BC9C", "#71074E"])

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


#### Numba

In [121]:
%%time

table_trunc = 10
ts_trunc = 5

ta = TensorAlgebra(dim=2, trunc=table_trunc)

Compiling...
Compilation finished.
CPU times: total: 5min 19s
Wall time: 6min 40s


In [126]:
rng = np.random.default_rng(seed=42)
array = rng.random(size=ta.alphabet.number_of_elements(ts_trunc))

In [127]:
ts = ta.from_array(trunc=table_trunc, array=array)

In [261]:
ts_2D = ta.from_array(trunc=table_trunc, array=ts_arr_2d)

In [265]:
%%timeit
res = ta.shuop.shuffle_prod(ts, ts)

14.2 ms ± 3.22 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [264]:
ta.print(ts_2D)

[0.77395605 0.77395605 0.77395605 0.77395605 0.77395605 0.77395605
 0.77395605 0.77395605 0.77395605 0.77395605]*∅ + [0.43887844 0.43887844 0.43887844 0.43887844 0.43887844 0.43887844
 0.43887844 0.43887844 0.43887844 0.43887844]*1 + [0.85859792 0.85859792 0.85859792 0.85859792 0.85859792 0.85859792
 0.85859792 0.85859792 0.85859792 0.85859792]*2 + [0.69736803 0.69736803 0.69736803 0.69736803 0.69736803 0.69736803
 0.69736803 0.69736803 0.69736803 0.69736803]*11 + [0.09417735 0.09417735 0.09417735 0.09417735 0.09417735 0.09417735
 0.09417735 0.09417735 0.09417735 0.09417735]*12 + [0.97562235 0.97562235 0.97562235 0.97562235 0.97562235 0.97562235
 0.97562235 0.97562235 0.97562235 0.97562235]*21 + [0.7611397 0.7611397 0.7611397 0.7611397 0.7611397 0.7611397 0.7611397
 0.7611397 0.7611397 0.7611397]*22 + [0.78606431 0.78606431 0.78606431 0.78606431 0.78606431 0.78606431
 0.78606431 0.78606431 0.78606431 0.78606431]*111 + [0.12811363 0.12811363 0.12811363 0.12811363 0.12811363 0.12811363
 

In [271]:
%%timeit
ta.shuop.shuffle_prod_2d(ts_2D, ts_2D).shape

76.9 ms ± 3.64 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [134]:
res = ta.shuop.shuffle_prod(ts, ts)
res.array.squeeze().real

array([ 0.59900797,  0.67934525,  1.32903411, ..., 79.59349657,
       67.10270751, 80.51194773])

#### Pure numpy

In [137]:
def __get_extended_array_np(ts: TensorSequence):
    n_elements = ta.alphabet.number_of_elements(table_trunc)

    new_array = np.zeros((n_elements,) + ts.shape[1:], dtype=complex)
    new_array[:min(n_elements, ts.shape[0])] = ts.array[:min(n_elements, ts.shape[0])]
    return new_array

def shuffle_prod_np(
    ts1: TensorSequence,
    ts2: TensorSequence,
    shuffle_table
):
    index_left, index_right, index_result, count = shuffle_table

    if ts1.trunc < table_trunc:
        array_1 = __get_extended_array_np(ts1)
    else:
        array_1 = ts1.array

    if ts2.trunc < table_trunc:
        array_2 = __get_extended_array_np(ts2)
    else:
        array_2 = ts2.array

    source = count * array_1[index_left, 0, 0] * array_2[index_right, 0, 0]
    linear_result = np.zeros(index_result[-1] + 1, dtype=complex)
    for i in range(len(index_result)):
        linear_result[index_result[i]] = linear_result[index_result[i]] + source[i]
    return TensorSequence(ta.alphabet, table_trunc, linear_result)

In [139]:
%%timeit
res2 = shuffle_prod_np(ts, ts, ta.shuop.shuffle_table)

145 ms ± 6.19 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [140]:
res2 = shuffle_prod_np(ts, ts, ta.shuop.shuffle_table)
ta.print(res2 - res)




In [251]:
import jax
import jax.numpy as jnp
from jax import lax
from functools import partial

from signature.signature.jax_signature.words import number_of_words_up_to_trunc

def change_trunc(ts: jax.Array, trunc: int, dim: int = 2):
    n_elements = number_of_words_up_to_trunc(trunc, dim)
    new_array = jnp.zeros(n_elements)
    new_array = new_array.at[:min(n_elements, ts.shape[0])].set(ts[:min(n_elements, ts.shape[0])])
    return new_array

@jax.jit
def shuffle_prod_jax(
    ts1: jax.Array,
    ts2: jax.Array,
    shuffle_table: jax.Array,
):
    index_left, index_right, index_result, count = shuffle_table

    source = count * ts1[index_left] * ts2[index_right]
    linear_result = ts1 * 0
    #jnp.zeros(number_of_words_up_to_trunc(trunc)) # index_result[-1] + 1
    linear_result = linear_result.at[index_result].add(source)
    
    return linear_result

shuffle_prod_jax_vect = jax.jit(jax.vmap(shuffle_prod_jax, in_axes=(1, 1, None), out_axes=1))

In [233]:
ts_arr = jnp.array(ts.array.squeeze()).real
ts_arr

Array([0.77395605, 0.43887844, 0.85859792, ..., 0.        , 0.        ,
       0.        ], dtype=float64)

In [248]:
ts_arr_2d = jnp.vstack([ts_arr] * 10).T

In [151]:
shuffle_table = get_shuffle_table(table_trunc=table_trunc)

In [215]:
%%time

change_trunc(ts=ts_arr, trunc=10)

CPU times: total: 0 ns
Wall time: 0 ns


Array([0.77395605, 0.43887844, 0.85859792, ..., 0.        , 0.        ,
       0.        ], dtype=float64)

In [229]:
%%timeit
ts_arr_2 = change_trunc(ts=ts_arr, trunc=10)

915 µs ± 90.4 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [240]:
%%timeit
ts_arr_2 = change_trunc(ts=ts_arr, trunc=10)
ts_arr_2 = change_trunc(ts=ts_arr, trunc=10)
shuffle_prod_jax(ts_arr, ts_arr, shuffle_table=shuffle_table)

5.17 ms ± 381 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [206]:
res.array.squeeze().real

array([ 0.59900797,  0.67934525,  1.32903411, ..., 79.59349657,
       67.10270751, 80.51194773])

In [252]:
shuffle_prod_jax(change_trunc(ts_arr, 3), ts_arr, jnp.array(shuffle_table))

Array([0.59900797, 0.67934525, 1.32903411, 1.46469298, 0.89941849,
       2.26381787, 2.65255813, 3.0531172 , 1.56115575, 2.83369937,
       1.56550026, 4.34478668, 3.50179781, 5.29232755, 4.60744399],      dtype=float64)

In [244]:
shuffle_prod_jax(ts_arr, ts_arr, jnp.array(shuffle_table))[:15]

Array([0.59900797, 0.67934525, 1.32903411, 1.46469298, 0.89941849,
       2.26381787, 2.65255813, 3.0531172 , 1.56115575, 2.83369937,
       1.56550026, 4.34478668, 3.50179781, 5.29232755, 4.60744399],      dtype=float64)

In [255]:
shuffle_prod_jax_vect(ts_arr_2d, ts_arr_2d, shuffle_table)

Array([[ 0.59900797,  0.59900797,  0.59900797, ...,  0.59900797,
         0.59900797,  0.59900797],
       [ 0.67934525,  0.67934525,  0.67934525, ...,  0.67934525,
         0.67934525,  0.67934525],
       [ 1.32903411,  1.32903411,  1.32903411, ...,  1.32903411,
         1.32903411,  1.32903411],
       ...,
       [79.59349657, 79.59349657, 79.59349657, ..., 79.59349657,
        79.59349657, 79.59349657],
       [67.10270751, 67.10270751, 67.10270751, ..., 67.10270751,
        67.10270751, 67.10270751],
       [80.51194773, 80.51194773, 80.51194773, ..., 80.51194773,
        80.51194773, 80.51194773]], dtype=float64)

In [256]:
%%timeit
shuffle_prod_jax_vect(ts_arr_2d, ts_arr_2d, shuffle_table)

22.9 ms ± 2.76 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [321]:
import jax_dataclasses as jdc

@jdc.pytree_dataclass
class TensorSequenceJAX:
    array: jax.Array
    trunc: int
    dim: int
    
    #def __init__(self)
    
    def __repr__(self):
        return str(self.array)
    
    @jax.jit
    def foo(self):
        return self.array * 2

In [322]:
ts_jax = TensorSequenceJAX(array=ts_arr, trunc=table_trunc, dim=dim)

In [323]:
ts_jax.foo()

Array([1.5479121 , 0.87775688, 1.71719584, ..., 0.        , 0.        ,
       0.        ], dtype=float64)

In [290]:
a = A(a=jnp.ones(5x), b=False)

In [305]:
@jax.jit
def foo(a: A):
    return a.b * a.a

In [292]:
foo(a)

Array([0., 0., 0., 0., 0.], dtype=float64)

In [281]:
a.a = jnp.ones(5)

FrozenInstanceError: Dataclass registered as pytree is immutable!