In [232]:
import jax.numpy as jnp
from functools import partial
from jax.lax import scan
import jax
from jax import random
import numpy as np

In [92]:
key = random.PRNGKey(42)
f = lambda x: jnp.exp(-jnp.sum(x**2, axis=1))[:, None]
n_samples = 10
x = random.normal(key, (n_samples, 2))

Let's get a baseline going:

In [93]:
_, vjp = jax.vjp(f, x)

In [38]:
vjp(jnp.ones((n_samples, 1)))

(DeviceArray([[ 0.8283536 , -0.22070852],
              [ 0.15463774,  0.8198199 ],
              [ 0.00879653, -0.05468073],
              [-0.06637478,  0.01172323],
              [ 0.73979336, -0.17955527],
              [ 0.48355308, -0.39451557],
              [-0.21793498,  0.55520076],
              [-0.09029141,  0.20403846],
              [ 0.05831185, -0.8185504 ],
              [ 0.65053004, -0.40385708]], dtype=float32),)

In [45]:
v = jnp.zeros((n_samples, 2))
v = (jax.ops.index_update(v, jax.ops.index[:, 1], 1), )

In [46]:
vjp2 = jax.vjp(vjp, jnp.ones((n_samples, 1)))[1]

In [47]:
vjp2(v)

(DeviceArray([[-0.22070852],
              [ 0.8198199 ],
              [-0.05468073],
              [ 0.01172323],
              [-0.17955527],
              [-0.39451557],
              [ 0.55520076],
              [ 0.20403846],
              [-0.8185504 ],
              [-0.40385708]], dtype=float32),)

In [49]:
jax.vjp(vjp2, v)[1](v)

((DeviceArray([[ 0.8283536 , -0.22070852],
               [ 0.15463774,  0.8198199 ],
               [ 0.00879653, -0.05468073],
               [-0.06637478,  0.01172323],
               [ 0.73979336, -0.17955527],
               [ 0.48355308, -0.39451557],
               [-0.21793498,  0.55520076],
               [-0.09029141,  0.20403846],
               [ 0.05831185, -0.8185504 ],
               [ 0.65053004, -0.40385708]], dtype=float32),),)

Alright let's try and get that prettier.

In [71]:
v_first = jnp.ones((n_samples, 1))
v_higher = jax.ops.index_update(jnp.zeros((n_samples, 2)), jax.ops.index[:, 1], 1)#, ) # tuple because of jvp structure

In [86]:
%%time
f_vals, vjp = jax.vjp(f, x)

df_vals, vjp = jax.vjp(lambda v: vjp(v)[0], v_first)

d2f_vals, vjp = jax.vjp(lambda v: vjp(v)[0], v_higher)

d3f_vals, vjp = jax.vjp(lambda v: vjp(v)[0], v_higher)

d4f_vals, vjp = jax.vjp(lambda v: vjp(v)[0], v_higher)


CPU times: user 11 ms, sys: 6.95 ms, total: 18 ms
Wall time: 14.2 ms


In [91]:
d2f_vals

DeviceArray([[-0.22070852],
             [ 0.8198199 ],
             [-0.05468073],
             [ 0.01172323],
             [-0.17955527],
             [-0.39451557],
             [ 0.55520076],
             [ 0.20403846],
             [-0.8185504 ],
             [-0.40385708]], dtype=float32)

In [57]:
df_vals

(DeviceArray([[ 0.8283536 , -0.22070852],
              [ 0.15463774,  0.8198199 ],
              [ 0.00879653, -0.05468073],
              [-0.06637478,  0.01172323],
              [ 0.73979336, -0.17955527],
              [ 0.48355308, -0.39451557],
              [-0.21793498,  0.55520076],
              [-0.09029141,  0.20403846],
              [ 0.05831185, -0.8185504 ],
              [ 0.65053004, -0.40385708]], dtype=float32),)

In [58]:
d2f_vals

(DeviceArray([[-0.22070852],
              [ 0.8198199 ],
              [-0.05468073],
              [ 0.01172323],
              [-0.17955527],
              [-0.39451557],
              [ 0.55520076],
              [ 0.20403846],
              [-0.8185504 ],
              [-0.40385708]], dtype=float32),)

In [59]:
d3f_vals

((DeviceArray([[ 0.8283536 , -0.22070852],
               [ 0.15463774,  0.8198199 ],
               [ 0.00879653, -0.05468073],
               [-0.06637478,  0.01172323],
               [ 0.73979336, -0.17955527],
               [ 0.48355308, -0.39451557],
               [-0.21793498,  0.55520076],
               [-0.09029141,  0.20403846],
               [ 0.05831185, -0.8185504 ],
               [ 0.65053004, -0.40385708]], dtype=float32),),)

Wont work cause im taking the derivative w.r.t the tangent and not the input x

In [94]:
f_vals, vjp = jax.vjp(f, x)

In [181]:
def vjp(f, v):
    def _vjp(x):
        return jax.vjp(f, x)[1](v)[0]
    return _vjp

In [182]:
first_vjp = vjp(f, v_first)

In [183]:
first_vjp(x)

DeviceArray([[ 0.8283536 , -0.22070852],
             [ 0.15463774,  0.8198199 ],
             [ 0.00879653, -0.05468073],
             [-0.06637478,  0.01172323],
             [ 0.73979336, -0.17955527],
             [ 0.48355308, -0.39451557],
             [-0.21793498,  0.55520076],
             [-0.09029141,  0.20403846],
             [ 0.05831185, -0.8185504 ],
             [ 0.65053004, -0.40385708]], dtype=float32)

In [184]:
- 2 * x * f(x)

DeviceArray([[ 0.8283536 , -0.22070852],
             [ 0.15463774,  0.8198199 ],
             [ 0.00879653, -0.05468073],
             [-0.06637478,  0.01172323],
             [ 0.73979336, -0.17955527],
             [ 0.48355308, -0.39451557],
             [-0.21793498,  0.55520076],
             [-0.09029141,  0.20403846],
             [ 0.05831185, -0.8185504 ],
             [ 0.65053004, -0.40385708]], dtype=float32)

In [185]:
second_vjp = vjp(first_vjp, v_higher)
third_vjp = vjp(second_vjp, v_higher)

DeviceArray([[-0.02359307, -1.1009936 ],
             [-0.9600828 ,  0.32658648],
             [-0.02084618,  0.19771393],
             [ 0.2312356 , -0.02505486],
             [-0.90232   , -1.5499763 ],
             [-1.49998   , -1.5885731 ],
             [-0.32423532,  0.69759923],
             [-0.01630688,  0.47565442],
             [-0.9462608 ,  0.45216346],
             [-1.0477774 , -1.3764659 ]], dtype=float32)

In [172]:
second_vjp(x)

DeviceArray([[-0.3090017 , -1.1009936 ],
             [ 0.2516502 ,  0.32658648],
             [-0.03609403,  0.19771393],
             [-0.04672404, -0.02505486],
             [-0.16703224, -1.5499763 ],
             [-0.21617593, -1.5885731 ],
             [-0.474165  ,  0.6975991 ],
             [-0.27071664,  0.47565442],
             [-0.10012902,  0.45216346],
             [-0.33201522, -1.3764659 ]], dtype=float32)

In [167]:
third_vjp(x)

DeviceArray([[-1.5414399 ,  1.293539  ],
             [ 0.10024831, -2.747808  ],
             [ 0.13050836, -0.592539  ],
             [ 0.09985851, -0.0645301 ],
             [-1.4418739 ,  1.0681783 ],
             [-0.87046313,  2.2882454 ],
             [-0.5957793 , -0.703024  ],
             [-0.6310946 ,  0.6099791 ],
             [ 0.0553108 ,  2.4977784 ],
             [-1.1316073 ,  2.317944  ]], dtype=float32)

In [206]:
def vjp(f, v):
    def _vjp(x):
        return jax.vjp(f, x)[1](v)[0]
    return _vjp

In [221]:
v_first = 
v_higher = 


order_vjp = vjp(f, v_first)
df = [order_vjp(x)]
deriv_order=3
for _ in jnp.arange(deriv_order - 1):
    order_vjp = vjp(order_vjp, v_higher)
    df.append(order_vjp(x))

In [275]:
@partial(jax.jit, static_argnums=(0, 2))
def nth_grad(f, x, order, prop_idx):
    """ Calculates gradient up to n-th order of f w.r.t input_idx column of x.
    prop_column is the propagated column, meaning for higher order stuff the deriv is 
    [dx^n, dx^(n-1)dym dx^(n-1)dz, ...], Returns tensor with shapes [n_samples, n_inputs, order]"""
    def vjp(f, v):
        def _vjp(x):
            return jax.vjp(f, x)[1](v)[0]
        return _vjp

    # First order
    # Separate cause only one output
    order_vjp = vjp(f, jnp.ones((x.shape[0], 1)))
    df = [order_vjp(x)]
    
    # Higher order
    v = jax.ops.index_update(jnp.zeros_like(x), jax.ops.index[:, prop_idx], 1)
    for _ in np.arange(order - 1):
        order_vjp = vjp(order_vjp, v)
        df.append(order_vjp(x))
    return jnp.stack(df, axis=-1)

In [265]:
df = nth_grad(f, x, 5, 1)

In [267]:
df.shape

(10, 2, 5)

In [272]:
df[:, 1, [0]]

DeviceArray([[-0.22070852],
             [ 0.8198199 ],
             [-0.05468073],
             [ 0.01172323],
             [-0.17955527],
             [-0.39451557],
             [ 0.55520076],
             [ 0.20403846],
             [-0.8185504 ],
             [-0.40385708]], dtype=float32)

In [271]:
- 2 * x * f(x)

DeviceArray([[ 0.8283536 , -0.22070852],
             [ 0.15463774,  0.8198199 ],
             [ 0.00879653, -0.05468073],
             [-0.06637478,  0.01172323],
             [ 0.73979336, -0.17955527],
             [ 0.48355308, -0.39451557],
             [-0.21793498,  0.55520076],
             [-0.09029141,  0.20403846],
             [ 0.05831185, -0.8185504 ],
             [ 0.65053004, -0.40385708]], dtype=float32)

In [273]:
df[:, 1, [1]]

DeviceArray([[-1.1009936 ],
             [ 0.32658648],
             [ 0.19771393],
             [-0.02505486],
             [-1.5499763 ],
             [-1.5885731 ],
             [ 0.6975991 ],
             [ 0.47565442],
             [ 0.45216346],
             [-1.3764659 ]], dtype=float32)

In [195]:
df[2](x)

DeviceArray([[-1.5414399 ,  1.293539  ],
             [ 0.10024831, -2.747808  ],
             [ 0.13050836, -0.592539  ],
             [ 0.09985851, -0.0645301 ],
             [-1.4418739 ,  1.0681783 ],
             [-0.87046313,  2.2882454 ],
             [-0.5957793 , -0.703024  ],
             [-0.6310946 ,  0.6099791 ],
             [ 0.0553108 ,  2.4977784 ],
             [-1.1316073 ,  2.317944  ]], dtype=float32)

In [274]:
-2 * f(x) + 4 * x**2 * f(x)

DeviceArray([[-0.02359307, -1.1009936 ],
             [-0.9600828 ,  0.32658648],
             [-0.02084618,  0.19771393],
             [ 0.2312356 , -0.02505486],
             [-0.90232   , -1.5499763 ],
             [-1.49998   , -1.5885731 ],
             [-0.32423532,  0.69759923],
             [-0.01630688,  0.47565442],
             [-0.9462608 ,  0.45216346],
             [-1.0477774 , -1.3764659 ]], dtype=float32)

In [196]:
df = [lambda x: jax.vjp(f, x)]

In [201]:
df[0](x)[1](v_first)

(DeviceArray([[ 0.8283536 , -0.22070852],
              [ 0.15463774,  0.8198199 ],
              [ 0.00879653, -0.05468073],
              [-0.06637478,  0.01172323],
              [ 0.73979336, -0.17955527],
              [ 0.48355308, -0.39451557],
              [-0.21793498,  0.55520076],
              [-0.09029141,  0.20403846],
              [ 0.05831185, -0.8185504 ],
              [ 0.65053004, -0.40385708]], dtype=float32),)

In [214]:
d2f = lambda x: jax.vjp(lambda y: df[-1](y)[1][0], x)

In [216]:
d2f(x)

DeviceArray(-0.84346145, dtype=float32)

In [277]:
for _ in np.arange(0):
    print('aap')

In [None]:

for order in jnp.arange(3):
    df.append(vjp(df[order], v_higher))

In [286]:
@partial(jax.jit, static_argnums=(0, 2))
def nth_deriv_backward(f: Callable, x: jnp.ndarray, order: int, prop_idx: int):
    """ Calculates gradient up to n-th order of f w.r.t input_idx column of x.
    prop_column is the propagated column, meaning for higher order stuff the deriv is 
    [dx^n, dx^(n-1)dym dx^(n-1)dz, ...], Returns tensor with shapes [n_samples, n_inputs, order]"""

    assert order > 0, "Order needs to be positive integer of 1 or higher."

    def vjp(f, v):
        def _vjp(x):
            return jax.vjp(f, x)[1](v)[0]

        return _vjp

    # First order
    # Separate cause only one output
    order_vjp = vjp(f, jnp.ones((x.shape[0], 1)))
    df = [order_vjp(x)]

    # Higher order
    v = jax.ops.index_update(jnp.zeros_like(x), jax.ops.index[:, prop_idx], 1)
    for _ in np.arange(order - 1):
        order_vjp = vjp(order_vjp, v)
        df.append(order_vjp(x))
    return jnp.stack(df, axis=-1)

In [284]:
from typing import Callable

In [292]:
# triggering jit
nth_deriv_backward(f, x, 5, 1);

In [293]:
%%timeit
nth_deriv_backward(f, x, 5, 1)

58.7 µs ± 475 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [294]:
from jax.experimental.jet import jet

In [299]:
jet(f, x, (v_higher, ))

AssertionError: 

In [300]:
v_higher

DeviceArray([[0., 1.],
             [0., 1.],
             [0., 1.],
             [0., 1.],
             [0., 1.],
             [0., 1.],
             [0., 1.],
             [0., 1.],
             [0., 1.],
             [0., 1.]], dtype=float32)

# Testing

In [19]:
from jax import random, numpy as jnp
from modax.data.burgers import burgers
from modax.data.kdv import doublesoliton
from modax.models.networks import MLP
from modax.layers.feature_generators import library_backward, library_backward_new
from modax.layers.feature_generators.utils import nth_deriv_backward

import jax

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [20]:
key = random.PRNGKey(42)

x = jnp.linspace(-3, 4, 50)
t = jnp.linspace(0.5, 5.0, 20)
t_grid, x_grid = jnp.meshgrid(t, x, indexing="ij")
u = burgers(x_grid, t_grid, 0.1, 1.0)

X = jnp.concatenate([t_grid.reshape(-1, 1), x_grid.reshape(-1, 1)], axis=1)
y = u.reshape(-1, 1)
y += 0.10 * jnp.std(y) * random.normal(key, y.shape)


In [21]:
model = MLP([30, 30, 30, 1])
params = model.init(key, X)

In [23]:
f = lambda x: model.apply(params, x)
lib_old = jax.jit(library_backward, static_argnums=(0, ))

In [28]:
prediction, dt_baseline, theta_baseline = lib_old(f, X)

In [30]:
derivs = nth_deriv_backward(f, X, 3, 1)

In [46]:
dt = derivs[:, 0, [0]]

In [48]:
jnp.max(jnp.abs(dt - dt_baseline))

DeviceArray(7.4505806e-08, dtype=float32)

In [51]:
derivs[:, 1, :]

DeviceArray([[ 0.03263867,  0.01787903, -0.01021587],
             [ 0.03505784,  0.01575143, -0.02016297],
             [ 0.03705848,  0.01192245, -0.03423945],
             ...,
             [-0.0612675 ,  0.15675765, -0.2661123 ],
             [-0.04176133,  0.11529644, -0.30945703],
             [-0.02852081,  0.06973058, -0.3238501 ]], dtype=float32)

In [52]:
theta_baseline[:, 1:4]

DeviceArray([[ 0.03263867,  0.01787903, -0.01021587],
             [ 0.03505784,  0.01575143, -0.02016297],
             [ 0.03705848,  0.01192245, -0.03423945],
             ...,
             [-0.06126748,  0.15675765, -0.2661123 ],
             [-0.04176133,  0.11529644, -0.30945703],
             [-0.02852081,  0.06973058, -0.3238501 ]], dtype=float32)

In [56]:
lib_new = jax.jit(library_backward_new(3, 2), static_argnums=(0, ))
#triggering jit
new = lib_new(f, X)[1][1]

In [57]:
new[:, 1:4]

DeviceArray([[ 0.03263867,  0.01787903, -0.01021587],
             [ 0.03505784,  0.01575143, -0.02016297],
             [ 0.03705848,  0.01192245, -0.03423945],
             ...,
             [-0.0612675 ,  0.15675765, -0.2661123 ],
             [-0.04176133,  0.11529644, -0.30945703],
             [-0.02852081,  0.06973058, -0.3238501 ]], dtype=float32)

In [59]:
jnp.max(jnp.abs(theta_baseline - new))

DeviceArray(2.3841858e-07, dtype=float32)



# Comparing speeds

In [1]:
from jax import random, numpy as jnp
from modax.data.burgers import burgers
from modax.data.kdv import doublesoliton
from modax.models.networks import MLP
from modax.layers.feature_generators import library_backward, library_backward_new
from modax.layers.feature_generators.utils import nth_deriv_backward

import jax

%load_ext autoreload
%autoreload 2

In [2]:
key = random.PRNGKey(42)

x = jnp.linspace(-3, 4, 50)
t = jnp.linspace(0.5, 5.0, 20)
t_grid, x_grid = jnp.meshgrid(t, x, indexing="ij")
u = burgers(x_grid, t_grid, 0.1, 1.0)

X = jnp.concatenate([t_grid.reshape(-1, 1), x_grid.reshape(-1, 1)], axis=1)
y = u.reshape(-1, 1)
y += 0.10 * jnp.std(y) * random.normal(key, y.shape)


In [3]:
model = MLP([30, 30, 30, 1])
params = model.init(key, X)

In [4]:
f = lambda x: model.apply(params, x)

In [5]:
lib_old = jax.jit(library_backward, static_argnums=(0, ))

In [6]:
# triggering jit
lib_old(f, X);

In [7]:
%%timeit 
lib_old(f, X)

200 µs ± 100 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [8]:
baseline = lib_old(f, X)

In [9]:
lib_new = jax.jit(library_backward_new(3, 2), static_argnums=(0, ))

In [10]:
#triggering jit
new = lib_new(f, X)

In [12]:
%%timeit 
lib_new(f, X)

149 µs ± 118 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [13]:
baseline[1][1] - new[1][1]

DeviceArray([[ 0.0000000e+00, -3.7252903e-09,  0.0000000e+00, ...,
              -1.8626451e-09,  0.0000000e+00,  0.0000000e+00],
             [ 0.0000000e+00, -3.7252903e-09,  0.0000000e+00, ...,
              -1.8626451e-09,  0.0000000e+00,  0.0000000e+00],
             [ 0.0000000e+00,  0.0000000e+00,  0.0000000e+00, ...,
               0.0000000e+00,  0.0000000e+00,  0.0000000e+00],
             ...,
             [ 0.0000000e+00,  7.4505806e-09,  0.0000000e+00, ...,
               3.7252903e-09,  0.0000000e+00,  0.0000000e+00],
             [ 0.0000000e+00,  7.4505806e-09,  0.0000000e+00, ...,
               3.7252903e-09,  0.0000000e+00,  0.0000000e+00],
             [ 0.0000000e+00, -3.7252903e-09,  0.0000000e+00, ...,
              -1.8626451e-09,  0.0000000e+00,  0.0000000e+00]],            dtype=float32)

In [14]:
lib_new = jax.jit(library_backward_new(5, 5), static_argnums=(0, ))

In [15]:
#triggering jit
new = lib_new(f, X)

In [16]:
%%timeit 
lib_new(f, X)

291 µs ± 284 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [19]:
lib_new(f, X)[1][1].shape

(1000, 36)

# Tetsing deepmod