# Pylops - basic linear operators with JAX

### Author: M.Ravasi

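In this notebook, I test some of the basic linear operators implemented in *Pylops* with the JAX backend. For each operator, the NumPy-backed version and the JAX-backed version (wrapped in `JaxOperator`) are applied to the same input, the JAX operator is checked with a dot test, and both versions are timed.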

In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import os
import numpy as np
import matplotlib.pyplot as plt
import scipy as sp
import pylops
import jax.numpy as jnp
import jax

from pylops.utils import dottest
from pylops.utils.wavelets import *
from pylops import LinearOperator
from pylops import JaxOperator
from pylops.basicoperators import *
from pylops.signalprocessing import *
from pylops.optimization.basic import cgls, lsqr
from pylops.optimization.leastsquares import *

os.environ["JAX_PYLOPS"] = '1'

## Matrix Multiplication

In [2]:
# Dense matrix operator: the same random matrix is wrapped once as a
# NumPy-backed MatrixMult and once (as a jnp array) inside JaxOperator.
# Gop, Gopjax, x, y, xjnp and yjnp are reused by the following cells.
ny, nx = 4, 4

G = np.random.normal(0, 1, (ny,nx)).astype('float32')
Gop = MatrixMult(G, dtype='float32')
x = np.ones(nx, dtype='float32')
y = Gop @ x

Gopjax = JaxOperator(MatrixMult(jnp.array(G), dtype='float32'))
xjnp = jnp.ones(nx, dtype='float32')
# block_until_ready waits for JAX's asynchronous dispatch to complete
yjnp = (Gopjax @ xjnp).block_until_ready()
print('y=', type(yjnp))

# dot test: checks that forward and adjoint of the JAX operator are consistent
dottest(Gopjax, ny, nx, backend='jax', verb=True, atol=1e-3)

%timeit -n 10 -r 2 Gop @ x
# block_until_ready so the timing includes the actual JAX computation
%timeit -n 10 -r 2 (Gopjax @ xjnp).block_until_ready()

I0000 00:00:1719954028.000026       1 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.


y= <class 'jaxlib.xla_extension.ArrayImpl'>
Dot test passed, v^H(Opu)=-0.31497400999069214 - u^H(Op^Hv)=-0.3149734139442444
7.51 µs ± 2.45 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
36.8 µs ± 1.18 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)


In [3]:
# Inversion with cgls
# Solve Gop x = y with CGLS for both backends, starting from a zero model
# (NumPy zeros for the NumPy run, jnp zeros for the JAX run).
xcgls = cgls(Gop, y, x0=np.zeros(nx), 
             niter=100, tol=1e-10, show=True)[0]

xcglsjnp = cgls(Gopjax, yjnp, x0=jnp.zeros(nx), 
                niter=100, tol=1e-10, show=True)[0]

# both estimates should recover the true model x = ones(nx)
xcgls, xcglsjnp

CGLS
-----------------------------------------------------------------
The Operator Op has 4 rows and 4 cols
damp = 0.000000e+00	tol = 1.000000e-10	niter = 100
-----------------------------------------------------------------

    Itn          x[0]              r1norm         r2norm
     1       -7.0939e-02         7.3634e-01     7.3634e-01
     2       -1.8172e-03         1.9142e-01     1.9142e-01
     3        7.4907e-01         1.9685e-02     1.9685e-02
     4        1.0000e+00         7.0531e-11     7.0531e-11

Iterations = 4        Total time (s) = 0.00
-----------------------------------------------------------------

CGLS
-----------------------------------------------------------------
The Operator Op has 4 rows and 4 cols
damp = 0.000000e+00	tol = 1.000000e-10	niter = 100
-----------------------------------------------------------------

    Itn          x[0]              r1norm         r2norm
     1       -7.0939e-02         7.3634e-01     7.3634e-01
     2       -1.8172e-03 

(array([1.0000004 , 1.00000088, 0.99999894, 1.00000039]),
 Array([0.9999997, 1.0000011, 0.9999986, 1.0000007], dtype=float32))

In [4]:
# Inversion with lsqr
# Solve Gop x = y with LSQR for both backends, starting from a zero model.
# The results get their own names so they do not overwrite the CGLS solutions
# computed in the previous cell.
xlsqr = lsqr(Gop, y, x0=np.zeros(nx),
             niter=100, atol=1e-10, show=True)[0]

xlsqrjnp = lsqr(Gopjax, yjnp, x0=jnp.zeros(nx),
                niter=100, atol=1e-10, show=True)[0]

# both estimates should recover the true model x = ones(nx)
xlsqr, xlsqrjnp

LSQR
------------------------------------------------------------------------------------------
The Operator Op has 4 rows and 4 cols
damp = 0.00000000000000e+00     calc_var =      1
atol = 1.00e-10                 conlim = 1.00e+08
btol = 1.00e-08                 niter =      100
------------------------------------------------------------------------------------------
   Itn     x[0]      r1norm     r2norm   Compatible   LS     Norm A   Cond A
     0  0.0000e+00  4.456e+00  4.456e+00   1.0e+00  5.9e-01
     1 -7.0939e-02  7.363e-01  7.363e-01   1.7e-01  1.5e-01  2.7e+00  1.0e+00
     2 -1.8172e-03  1.914e-01  1.914e-01   4.3e-02  3.6e-03  3.7e+00  2.0e+00
     3  7.4907e-01  1.968e-02  1.968e-02   4.4e-03  5.7e-05  3.7e+00  1.7e+01
     4  1.0000e+00  1.714e-10  1.714e-10   3.8e-11  4.0e-11  3.7e+00  1.1e+02
 
LSQR finished, Op x - b is small enough, given atol, btol                 
 
istop =       1   r1norm = 1.7e-10   anorm = 3.7e+00   arnorm = 4.8e-10
itn   =       4   r2norm =

(array([1.0000004 , 1.00000088, 0.99999894, 1.00000039]),
 Array([1.0000026 , 0.99999934, 1.0000015 , 0.9999989 ], dtype=float32))

In [5]:
# Using AD to implement rmatvec
# Compare the explicit adjoint (Gopjax.H @ yjnp) with the AD-based adjoint
# returned by rmatvecad; the difference is expected to be (close to) zero.
xxjnp = Gopjax.H @ yjnp
xxjnpad = Gopjax.rmatvecad(xjnp, yjnp)

xxjnp - xxjnpad

Array([0., 0., 0., 0.], dtype=float32)

In [6]:
## Automatic vectorization
# jax.vmap turns the single-vector _matvec into a function that maps over the
# leading (batch) axis of its input
auto_batch_matvec = jax.vmap(Gopjax._matvec)

# batch of two identical model vectors
xs = jnp.stack([xjnp, xjnp])
ys = auto_batch_matvec(xs)

# the first batched output should match the unbatched forward result
ys[0], yjnp

(Array([ 0.29720128, -4.3841143 ,  0.408643  ,  0.61299235], dtype=float32),
 Array([ 0.29720122, -4.3841143 ,  0.40864307,  0.61299235], dtype=float32))

In [7]:
## Automatic differentiation
def fun(x):
    """Scalar loss for jax.grad: the sum of all entries of Gopjax applied to x."""
    return jnp.sum(Gopjax(x))

jax.grad(fun)(xjnp)

Array([-0.24497196,  0.52037215, -0.6736099 , -2.667068  ], dtype=float32)

In [8]:
# JIT of _matvec and _rmatvec
Gopjax_matvec = jax.jit(Gopjax._matvec)
_ = Gopjax_matvec(xjnp).block_until_ready()
Gopjax_rmatvec = jax.jit(Gopjax._rmatvec)
_ = Gopjax_rmatvec(xjnp).block_until_ready()

%timeit -n 50 -r 2 (Gopjax @ xjnp).block_until_ready()
%timeit -n 50 -r 2 Gopjax_matvec(xjnp).block_until_ready()

%timeit -n 50 -r 2 (Gopjax.H @ yjnp).block_until_ready()
%timeit -n 50 -r 2 Gopjax_rmatvec(yjnp).block_until_ready()

33.8 µs ± 1.6 µs per loop (mean ± std. dev. of 2 runs, 50 loops each)
8.15 µs ± 2.01 µs per loop (mean ± std. dev. of 2 runs, 50 loops each)
56 µs ± 2.08 µs per loop (mean ± std. dev. of 2 runs, 50 loops each)
6.14 µs ± 131 ns per loop (mean ± std. dev. of 2 runs, 50 loops each)


## Identity

In [9]:
# Square identity (ny == nx): forward (Iop @ x) and adjoint (Iop.H @ y) both
# reproduce the input
ny, nx = 5, 5 

Iop = Identity(ny, nx, dtype='float32')
x = np.arange(nx, dtype='float32')
y = Iop @ x

Iopjax = JaxOperator(Identity(ny, nx, dtype='float32'))
xjnp = jnp.arange(nx, dtype='float32')
yjnp = (Iopjax @ xjnp).block_until_ready()
y1jnp = (Iopjax.H @ yjnp).block_until_ready()

dottest(Iopjax, ny, nx, backend='jax', verb=True, atol=1e-3)

print('x  = ', x)
print('y  = ', yjnp, type(yjnp))
print('y1  = ', y1jnp, type(y1jnp))

%timeit -n 10 -r 2 Iop @ x
%timeit -n 10 -r 2 (Iopjax @ xjnp).block_until_ready()

Dot test passed, v^H(Opu)=1.1268846988677979 - u^H(Op^Hv)=1.1268846988677979
x  =  [0. 1. 2. 3. 4.]
y  =  [0. 1. 2. 3. 4.] <class 'jaxlib.xla_extension.ArrayImpl'>
y1  =  [0. 1. 2. 3. 4.] <class 'jaxlib.xla_extension.ArrayImpl'>
6.01 µs ± 1.34 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
45.9 µs ± 13.8 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)


In [10]:
# Rectangular identity with ny < nx: the forward keeps the first ny entries
# of the model
ny, nx = 5, 7

Iop = Identity(ny, nx, dtype='float32')
x = np.arange(nx, dtype='float32')
y = Iop @ x

Iopjax = JaxOperator(Identity(ny, nx, dtype='float32'))
xjnp = jnp.arange(nx, dtype='float32')
yjnp = (Iopjax @ xjnp).block_until_ready()

dottest(Iopjax, ny, nx, backend='jax', verb=True, atol=1e-3)

print('x  = ', x)
print('y  = ', yjnp, type(yjnp))

%timeit -n 10 -r 2 Iop @ x
%timeit -n 10 -r 2 (Iopjax @ xjnp).block_until_ready()

Dot test passed, v^H(Opu)=0.03560042381286621 - u^H(Op^Hv)=0.03560042381286621
x  =  [0. 1. 2. 3. 4. 5. 6.]
y  =  [0. 1. 2. 3. 4.] <class 'jaxlib.xla_extension.ArrayImpl'>
6.05 µs ± 1.26 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
46.2 µs ± 13.8 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)


In [11]:
# Rectangular identity with ny > nx: the forward pads the model with zeros
# up to length ny
ny, nx = 7, 5

Iop = Identity(ny, nx, dtype='float32')
x = np.arange(nx, dtype='float32')
y = Iop @ x

Iopjax = JaxOperator(Identity(ny, nx, dtype='float32'))
xjnp = jnp.arange(nx, dtype='float32')
yjnp = (Iopjax @ xjnp).block_until_ready()

dottest(Iopjax, ny, nx, backend='jax', verb=True, atol=1e-3)

print('x  = ', x)
print('y  = ', yjnp, type(yjnp))

%timeit -n 10 -r 2 Iop @ x
%timeit -n 10 -r 2 (Iopjax @ xjnp).block_until_ready()

Dot test passed, v^H(Opu)=-2.5060155391693115 - u^H(Op^Hv)=-2.5060155391693115
x  =  [0. 1. 2. 3. 4.]
y  =  [0. 1. 2. 3. 4. 0. 0.] <class 'jaxlib.xla_extension.ArrayImpl'>
8.12 µs ± 1.43 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
45.5 µs ± 13 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)


## Diagonal

In [12]:
# Diagonal operator with a unit diagonal: y should equal x
nx = 10

d = np.ones(nx, dtype='float32')
Dop = Diagonal(d, dtype='float32')
x = np.ones(nx, dtype='float32')
y = Dop @ x

# JAX version: the diagonal itself is a jnp array
djnp = jnp.ones(nx, dtype='float32')
Dopjax = JaxOperator(Diagonal(djnp, dtype='float32'))
xjnp = jnp.ones(nx)
yjnp = (Dopjax @ xjnp).block_until_ready()

dottest(Dopjax, nx, nx, backend='jax', verb=True, atol=1e-3)

print('y=', yjnp, type(yjnp))

%timeit -n 10 -r 2 Dop @ x
%timeit -n 10 -r 2 (Dopjax @ xjnp).block_until_ready()

Dot test passed, v^H(Opu)=1.083421230316162 - u^H(Op^Hv)=1.083421230316162
y= [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.] <class 'jaxlib.xla_extension.ArrayImpl'>
6.65 µs ± 1.53 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
64.8 µs ± 30.6 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)


In [13]:
# Complex numbers 
nx = 10

d = np.ones(nx, dtype='float32') + 1j*np.ones(nx, dtype='float32')
Dop = Diagonal(d, dtype='float32')
x = np.ones(nx, dtype='float32')+ 1j*np.ones(nx, dtype='float32')
y = Dop @ x

djnp = jnp.ones(nx, dtype='float32')
Dopjax = JaxOperator(Diagonal(djnp, dtype='float32'))
xjnp = jnp.ones(nx)
yjnp = (Dopjax @ xjnp).block_until_ready()

dottest(Dopjax, nx, nx, complexflag=2, backend='jax', verb=True, atol=1e-3)

print('y=', yjnp, type(yjnp))

%timeit -n 10 -r 2 Dop @ x
%timeit -n 10 -r 2 (Dopjax @ xjnp).block_until_ready()

Dot test passed, v^H(Opu)=(-3.155590772628784+1.8626459836959839j) - u^H(Op^Hv)=(-3.155590772628784+1.8626459836959839j)
y= [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.] <class 'jaxlib.xla_extension.ArrayImpl'>
6.52 µs ± 1.41 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
47.2 µs ± 14.3 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)


## Zero

In [14]:
# Zero operator (square): output is a length-ny vector of zeros
ny, nx = 5, 5

Zop = Zero(ny, nx, dtype='float32')
x = np.ones(nx, dtype='float32')
y = Zop @ x

Zopjax = JaxOperator(Zero(ny, nx, dtype='float32'))
xjnp = jnp.ones(nx)
yjnp = (Zopjax @ xjnp).block_until_ready()

dottest(Zopjax, ny, nx, backend='jax', verb=True, atol=1e-3)

print('y=', yjnp, type(yjnp))

%timeit -n 10 -r 2 Zop @ x
%timeit -n 10 -r 2 (Zopjax @ xjnp).block_until_ready()

Dot test passed, v^H(Opu)=0.0 - u^H(Op^Hv)=0.0
y= [0. 0. 0. 0. 0.] <class 'jaxlib.xla_extension.ArrayImpl'>
5.71 µs ± 1.07 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
44.6 µs ± 13.3 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)


In [15]:
# Zero operator with ny > nx: output is a length-ny vector of zeros
ny, nx = 7, 5

Zop = Zero(ny, nx, dtype='float32')
x = np.ones(nx, dtype='float32')
y = Zop @ x

Zopjax = JaxOperator(Zero(ny, nx, dtype='float32'))
xjnp = jnp.ones(nx)
yjnp = (Zopjax @ xjnp).block_until_ready()

dottest(Zopjax, ny, nx, backend='jax', verb=True, atol=1e-3)

print('y=', yjnp, type(yjnp))

%timeit -n 10 -r 2 Zop @ x
%timeit -n 10 -r 2 (Zopjax @ xjnp).block_until_ready()

Dot test passed, v^H(Opu)=0.0 - u^H(Op^Hv)=0.0
y= [0. 0. 0. 0. 0. 0. 0.] <class 'jaxlib.xla_extension.ArrayImpl'>
5.67 µs ± 1.31 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
45.5 µs ± 14.1 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)


In [16]:
# Zero operator with ny < nx: output is a length-ny vector of zeros
ny, nx = 5, 7

Zop = Zero(ny, nx, dtype='float32')
x = np.ones(nx, dtype='float32')
y = Zop @ x

Zopjax = JaxOperator(Zero(ny, nx, dtype='float32'))
xjnp = jnp.ones(nx)
yjnp = (Zopjax @ xjnp).block_until_ready()

dottest(Zopjax, ny, nx, backend='jax', verb=True, atol=1e-3)

print('y=', yjnp, type(yjnp))

%timeit -n 10 -r 2 Zop @ x
%timeit -n 10 -r 2 (Zopjax @ xjnp).block_until_ready()

Dot test passed, v^H(Opu)=0.0 - u^H(Op^Hv)=0.0
y= [0. 0. 0. 0. 0.] <class 'jaxlib.xla_extension.ArrayImpl'>
5.5 µs ± 1.16 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
43 µs ± 12.1 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)


## Sum

In [17]:
# Sum along axis 0 of a (ny, nx) model: data has nx samples, model ny*nx
# (the NumPy operator uses its default dtype, the JAX one float32)
ny, nx = 5, 7

Sop = Sum(dims=(ny, nx), axis=0)
x   = (np.arange(ny*nx)).reshape(ny, nx)
y = Sop @ x

Sopjax = JaxOperator(Sum(dims=(ny, nx), axis=0, dtype='float32'))
xjnp = (jnp.arange(ny*nx)).reshape(ny, nx)
yjnp = (Sopjax @ xjnp).block_until_ready()

dottest(Sopjax, nx, ny*nx, backend='jax', verb=True, atol=1e-3)

print('y=', type(yjnp))

%timeit -n 10 -r 2 Sop @ x
%timeit -n 10 -r 2 (Sopjax @ xjnp).block_until_ready()

Dot test passed, v^H(Opu)=-4.44984769821167 - u^H(Op^Hv)=-4.449847221374512
y= <class 'jaxlib.xla_extension.ArrayImpl'>
9.67 µs ± 2.1 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
57.5 µs ± 21.9 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)


## Transpose

In [18]:
# Transpose: swap the two axes of a (ny, nx) model (axes=(1, 0))
ny, nx = 20, 40
dims = (ny, nx)

Top = Transpose(dims=dims, axes=(1,0))
x = np.arange(ny*nx).reshape(dims)
y = Top @ x

Topjax = JaxOperator(Transpose(dims=dims, axes=(1,0), dtype='float32'))
xjnp = (jnp.arange(ny*nx)).reshape(dims)
yjnp = (Topjax @ xjnp).block_until_ready()

dottest(Topjax, ny*nx, ny*nx, backend='jax', verb=True, atol=1e-3)

print('y=', type(yjnp))

%timeit -n 10 -r 2 Top @ x
%timeit -n 10 -r 2 (Topjax @ xjnp).block_until_ready()

Dot test passed, v^H(Opu)=6.494052886962891 - u^H(Op^Hv)=6.494050025939941
y= <class 'jaxlib.xla_extension.ArrayImpl'>
The slowest run took 5.07 times longer than the fastest. This could mean that an intermediate result is being cached.
18 µs ± 12 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
71.1 µs ± 9.67 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)


## Flip

In [19]:
# Flip of a 1D vector of length nt
nt = 10

Fop = Flip(nt)
x = np.arange(nt)
y = Fop @ x

Fopjax = JaxOperator(Flip(nt, dtype='float32'))
xjnp = jnp.arange(nt)
yjnp = (Fopjax @ xjnp).block_until_ready()

dottest(Fopjax, nt, nt, backend='jax', verb=True, atol=1e-3)

print('y=', type(yjnp))

%timeit -n 10 -r 2 Fop @ x
%timeit -n 10 -r 2 (Fopjax @ xjnp).block_until_ready()

Dot test passed, v^H(Opu)=-0.17199133336544037 - u^H(Op^Hv)=-0.17199134826660156
y= <class 'jaxlib.xla_extension.ArrayImpl'>
10.1 µs ± 2.22 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
45.1 µs ± 12.9 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)


In [20]:
# Flip of a (nt, nx) model, along each axis in turn
nt, nx = 10, 5
x = np.outer(np.arange(nt), np.ones(nx))

for axis in (0, 1):
    Fop = Flip(dims=(nt, nx), axis=axis)
    y = Fop @ x

    Fopjax = JaxOperator(Flip(dims=(nt, nx), axis=axis, dtype='float32'))
    xjnp = jnp.outer(jnp.arange(nt), jnp.ones(nx))
    yjnp = (Fopjax @ xjnp).block_until_ready()

    dottest(Fopjax, nt*nx, nt*nx, backend='jax', verb=True, atol=1e-6)

    print('y=', type(yjnp))

    %timeit -n 10 -r 2 Fop @ x
    %timeit -n 10 -r 2 (Fopjax @ xjnp).block_until_ready()

Dot test passed, v^H(Opu)=4.9428839683532715 - u^H(Op^Hv)=4.942883491516113
y= <class 'jaxlib.xla_extension.ArrayImpl'>
The slowest run took 4.08 times longer than the fastest. This could mean that an intermediate result is being cached.
20.4 µs ± 12.4 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
66.6 µs ± 7.01 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
Dot test passed, v^H(Opu)=-1.4609696865081787 - u^H(Op^Hv)=-1.4609711170196533
y= <class 'jaxlib.xla_extension.ArrayImpl'>
10 µs ± 1.94 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
65.4 µs ± 5.47 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)


In [21]:
# Flip of a (nt, nx, ny) model, along each of the three axes in turn; the
# model varies only along the last axis (values 0..ny-1)
nt, nx, ny = 2, 3, 4
x = np.outer(np.ones(nt), np.ones(nx))[:, :, np.newaxis] * np.arange(ny)

for axis in (0, 1, 2):
    Fop = Flip(dims=(nt, nx, ny), axis=axis)
    y = Fop @ x

    Fopjax = JaxOperator(Flip(dims=(nt, nx, ny), axis=axis, dtype='float32'))
    xjnp = jnp.outer(jnp.ones(nt), jnp.ones(nx))[:, :, jnp.newaxis] * jnp.arange(ny)
    yjnp = (Fopjax @ xjnp).block_until_ready()

    dottest(Fopjax, nt*nx*ny, nt*nx*ny, backend='jax', verb=True, atol=1e-6)

    print('y=', type(yjnp))

    %timeit -n 10 -r 2 Fop @ x
    %timeit -n 10 -r 2 (Fopjax @ xjnp).block_until_ready()

Dot test passed, v^H(Opu)=-6.490195274353027 - u^H(Op^Hv)=-6.490195274353027
y= <class 'jaxlib.xla_extension.ArrayImpl'>
12.1 µs ± 2.57 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
80.1 µs ± 16 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
Dot test passed, v^H(Opu)=-5.266023635864258 - u^H(Op^Hv)=-5.266023635864258
y= <class 'jaxlib.xla_extension.ArrayImpl'>
10.4 µs ± 2.15 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
68.3 µs ± 4.25 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
Dot test passed, v^H(Opu)=6.5485663414001465 - u^H(Op^Hv)=6.5485663414001465
y= <class 'jaxlib.xla_extension.ArrayImpl'>
10.5 µs ± 2.39 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
68.6 µs ± 6.21 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)


## Roll

In [22]:
# Roll of a (nt, nx) model along each axis in turn (no shift passed, so the
# operator's default shift is used)
nt, nx = 10, 5
x = np.outer(np.arange(nt), np.ones(nx))


for axis in (0, 1):
    
    Rop = Roll((nt, nx), axis=axis)
    y = Rop @ x

    Ropjax = JaxOperator(Roll((nt, nx), axis=axis, dtype='float32'))
    xjnp = jnp.outer(jnp.arange(nt), jnp.ones(nx))
    yjnp = (Ropjax @ xjnp).block_until_ready()

    dottest(Ropjax, nt*nx, nt*nx, backend='jax', verb=True, atol=1e-3)

    print('y=', type(yjnp))

    %timeit -n 10 -r 2 Rop @ x
    %timeit -n 10 -r 2 (Ropjax @ xjnp).block_until_ready()

Dot test passed, v^H(Opu)=-4.097294330596924 - u^H(Op^Hv)=-4.097293853759766
y= <class 'jaxlib.xla_extension.ArrayImpl'>
27.2 µs ± 3.17 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
79.5 µs ± 17.9 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
Dot test passed, v^H(Opu)=4.786231994628906 - u^H(Op^Hv)=4.786231994628906
y= <class 'jaxlib.xla_extension.ArrayImpl'>
20.7 µs ± 2.76 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
97.3 µs ± 32.2 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)


In [23]:
# Roll of a 1D vector with an explicit shift of -5 samples
nt = 10

Rop = Roll(nt, shift=-5)
x = np.arange(nt)
y = Rop @ x

Ropjax = JaxOperator(Roll(nt, shift=-5, dtype='float32'))
xjnp = jnp.arange(nt)
yjnp = (Ropjax @ xjnp).block_until_ready()

dottest(Ropjax, nt, nt, backend='jax', verb=True, atol=1e-6)

print('y=', type(yjnp))

%timeit -n 10 -r 2 Rop @ x
%timeit -n 10 -r 2 (Ropjax @ xjnp).block_until_ready()

Dot test passed, v^H(Opu)=-1.3992513418197632 - u^H(Op^Hv)=-1.399251103401184
y= <class 'jaxlib.xla_extension.ArrayImpl'>
18.2 µs ± 4.16 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
45.7 µs ± 14.3 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)


## Pad

In [24]:
# Pad a length-10 vector with 10 samples before and 3 after, so the data has
# dims + pad[0] + pad[1] samples
dims = 10
pad = (10, 3)

Pop = Pad(dims, pad)
x = np.arange(dims)+1.
y = Pop @ x

Popjax = JaxOperator(Pad(dims, pad, dtype='float32'))
xjnp = jnp.arange(dims)+1.
yjnp = (Popjax @ xjnp).block_until_ready()

dottest(Popjax, dims+pad[0]+pad[1], dims, backend='jax', verb=True, atol=1e-6)

print('y=', type(yjnp))

%timeit -n 10 -r 2 Pop @ x
%timeit -n 10 -r 2 (Popjax @ xjnp).block_until_ready()

Dot test passed, v^H(Opu)=-3.463016986846924 - u^H(Op^Hv)=-3.463016986846924
y= <class 'jaxlib.xla_extension.ArrayImpl'>
25.5 µs ± 4.45 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
45 µs ± 13.7 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)


In [25]:
# Pad a (5, 4) model: (1, 0) samples along axis 0 and (3, 4) along axis 1
dims = (5, 4)
pad = ((1, 0), (3, 4))

Pop = Pad(dims, pad)
x = (np.arange(np.prod(np.array(dims)))+1.).reshape(dims)
y = Pop @ x

Popjax = JaxOperator(Pad(dims, pad, dtype='float32'))
xjnp = (jnp.arange(jnp.prod(jnp.array(dims)))+1.).reshape(dims)
yjnp = (Popjax @ xjnp).block_until_ready()

# no explicit sizes are passed to dottest in this cell
dottest(Popjax, backend='jax', verb=True, atol=1e-6)

print('y=', type(yjnp))

%timeit -n 10 -r 2 Pop @ x
%timeit -n 10 -r 2 (Popjax @ xjnp).block_until_ready()

Dot test passed, v^H(Opu)=0.9497206807136536 - u^H(Op^Hv)=0.9497206807136536
y= <class 'jaxlib.xla_extension.ArrayImpl'>
42.4 µs ± 7.47 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
77.2 µs ± 17 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)


## Symmetrize

In [26]:
# Symmetrize a 1D vector of length nt
nt = 10

Sop = Symmetrize(nt)
x = np.arange(nt)+1
y = Sop @ x

Sopjax = JaxOperator(Symmetrize(nt, dtype='float32'))
xjnp = jnp.arange(nt)+1
yjnp = (Sopjax @ xjnp).block_until_ready()

dottest(Sopjax, backend='jax', verb=True, atol=1e-6)

print('y=', type(yjnp))

%timeit -n 10 -r 2 Sop @ x
%timeit -n 10 -r 2 (Sopjax @ xjnp).block_until_ready()

Dot test passed, v^H(Opu)=1.8043608665466309 - u^H(Op^Hv)=1.8043606281280518
y= <class 'jaxlib.xla_extension.ArrayImpl'>
The slowest run took 4.30 times longer than the fastest. This could mean that an intermediate result is being cached.
20.6 µs ± 12.8 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
41.1 µs ± 5.75 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)


## Restriction

In [27]:
# Restriction to a random, sorted subset of 40% of the N samples, tested with
# both inplace=True and inplace=False
N=200

perc_subsampling=0.4
Nsub = int(np.round(N*perc_subsampling))
# Nsub distinct indices drawn at random from 0..N-1, sorted ascending
iava = np.sort(np.random.permutation(np.arange(N))[:Nsub])

for inplace in (True, False):
    Rop = Restriction(N, iava, inplace=inplace, dtype='float64')
    x = np.zeros(N)
    y = Rop @ x

    Ropjax = JaxOperator(Restriction(N, iava, inplace=inplace, dtype='float32'))
    xjnp = jnp.zeros(N)
    yjnp = (Ropjax @ xjnp).block_until_ready()

    dottest(Ropjax, backend='jax', verb=True, atol=1e-6)

    print('y=', type(yjnp))

    %timeit -n 10 -r 2 Rop @ x
    %timeit -n 10 -r 2 (Ropjax @ xjnp).block_until_ready()

Dot test passed, v^H(Opu)=-8.079109191894531 - u^H(Op^Hv)=-8.079109191894531
y= <class 'jaxlib.xla_extension.ArrayImpl'>
11 µs ± 3.11 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
76.4 µs ± 20.3 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
Dot test passed, v^H(Opu)=6.646598815917969 - u^H(Op^Hv)=6.646598815917969
y= <class 'jaxlib.xla_extension.ArrayImpl'>
9.78 µs ± 1.84 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
34.2 µs ± 2.21 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)


## Regression

In [28]:
# parameters
N = 30
t = np.arange(N, dtype='float64')

LRop = LinearRegression(jnp.array(t), dtype=None)
x = np.array([1., 2.])
y = LRop @ x

LRopjax = JaxOperator(LinearRegression(t, dtype=None))
xjnp = jnp.array(x)
yjnp = (LRopjax @ xjnp).block_until_ready()

dottest(LRopjax, N, 2, backend='jax', verb=True, atol=1e-3)

print('y=', type(yjnp))

%timeit -n 10 -r 2 LRop @ x
%timeit -n 10 -r 2 (LRopjax @ xjnp).block_until_ready()

Dot test passed, v^H(Opu)=-25.029476165771484 - u^H(Op^Hv)=-25.02947235107422
y= <class 'jaxlib.xla_extension.ArrayImpl'>
238 µs ± 37.1 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
34.5 µs ± 2.91 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)


  u = randn(nc, backend).astype(rdtype)
  v = randn(nr, backend).astype(rdtype)


In [29]:
# CGLS solver on the JAX regression problem from the previous cell
# (starting from a zero model; the true coefficients are [1, 2])
xcglsjnp = cgls(LRopjax, yjnp, jnp.zeros(2), damp=1e-10, niter=10, show=False)[0]

print('cgls solution xcgls=', xcglsjnp)

cgls solution xlsqr= [0.99999994 2.        ]


## First Derivative

In [30]:
nx = 11
   
for kind in ('forward', 'centered', 'backward'):
    for edge in (False, True):
        for order in (3, 5):
            print(kind, edge, order)
            D1op = FirstDerivative(nx, edge=edge, kind=kind, 
                                   order=order if kind == 'centered' else 3, 
                                   dtype='float32')
            x = np.ones(nx)
            y = D1op @ x

            D1opjax = JaxOperator(FirstDerivative(nx, edge=edge, kind=kind, 
                                                  order=order if kind == 'centered' else 3, 
                                                  dtype='float32'))
            xjnp = jnp.array(x)
            yjnp = (D1opjax @ xjnp).block_until_ready()

            dottest(D1opjax, nx, nx, backend='jax', verb=True, atol=1e-3)
            print('y=', type(yjnp))

            %timeit -n 10 -r 2 D1op @ x
            %timeit -n 10 -r 2 (D1opjax @ xjnp).block_until_ready()

forward False 3
Dot test passed, v^H(Opu)=4.77616548538208 - u^H(Op^Hv)=4.7761664390563965
y= <class 'jaxlib.xla_extension.ArrayImpl'>
11.4 µs ± 1.92 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
34.1 µs ± 2.95 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
forward False 5
Dot test passed, v^H(Opu)=2.7915396690368652 - u^H(Op^Hv)=2.7915399074554443
y= <class 'jaxlib.xla_extension.ArrayImpl'>
10.8 µs ± 2.14 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
33.2 µs ± 2.15 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
forward True 3
Dot test passed, v^H(Opu)=0.9124404191970825 - u^H(Op^Hv)=0.9124410152435303
y= <class 'jaxlib.xla_extension.ArrayImpl'>
The slowest run took 4.22 times longer than the fastest. This could mean that an intermediate result is being cached.
24.3 µs ± 15 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
36.2 µs ± 4.3 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
forward True 5
Dot test passed, v^H(Opu)=6.94415235

In [31]:
ny, nx = 21, 11
   
for kind in ('forward', 'centered', 'backward'):
    for edge in (False, True):
        print(kind, edge, order)
        D1op = FirstDirectionalDerivative((ny, nx), v=0.5 * np.ones(2), edge=edge, kind=kind, 
                                          dtype='float32')
        x = np.ones(ny*nx)
        y = D1op @ x

        D1opjax = JaxOperator(FirstDirectionalDerivative((ny, nx), v=0.5 * np.ones(2), edge=edge, kind=kind, 
                                                         dtype='float32'))
        xjnp = jnp.array(x)
        yjnp = (D1opjax @ xjnp).block_until_ready()

        dottest(D1opjax, ny*nx, ny*nx, backend='jax', verb=True, atol=1e-3)
        print('y=', type(yjnp))

        %timeit -n 10 -r 2 D1op @ x
        %timeit -n 10 -r 2 (D1opjax @ xjnp).block_until_ready()

forward False 5
Dot test passed, v^H(Opu)=15.370150566101074 - u^H(Op^Hv)=15.370157241821289
y= <class 'jaxlib.xla_extension.ArrayImpl'>
67.1 µs ± 22.4 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
45 µs ± 20.1 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
forward True 5
Dot test passed, v^H(Opu)=31.837804794311523 - u^H(Op^Hv)=31.837799072265625
y= <class 'jaxlib.xla_extension.ArrayImpl'>
77.1 µs ± 7.94 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
32 µs ± 9.37 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
centered False 5
Dot test passed, v^H(Opu)=-0.30357295274734497 - u^H(Op^Hv)=-0.3035726249217987
y= <class 'jaxlib.xla_extension.ArrayImpl'>
65.4 µs ± 12.9 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
36.6 µs ± 13.9 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
centered True 5
Dot test passed, v^H(Opu)=-3.91501522064209 - u^H(Op^Hv)=-3.9150142669677734
y= <class 'jaxlib.xla_extension.ArrayImpl'>
78.8 µs ± 28.6 µs per loop 

## Second Derivative

In [32]:
nx = 11
   
for kind in ('forward', 'centered', 'backward'):
    for edge in (False, True):
        print(kind, edge, order)
        D1op = SecondDerivative(nx, edge=edge, kind=kind, 
                                dtype='float32')
        x = np.ones(nx)
        y = D1op @ x

        D1opjax = JaxOperator(SecondDerivative(nx, edge=edge, kind=kind, dtype='float32'))
        xjnp = jnp.array(x)
        yjnp = (D1opjax @ xjnp).block_until_ready()

        dottest(D1opjax, nx, nx, backend='jax', verb=True, atol=1e-3)
        print('y=', type(yjnp))

        %timeit -n 10 -r 2 D1op @ x
        %timeit -n 10 -r 2 (D1opjax @ xjnp).block_until_ready()

forward False 5
Dot test passed, v^H(Opu)=-6.706727504730225 - u^H(Op^Hv)=-6.706727027893066
y= <class 'jaxlib.xla_extension.ArrayImpl'>
24 µs ± 10.9 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
37.5 µs ± 5.14 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
forward True 5
Dot test passed, v^H(Opu)=11.86607551574707 - u^H(Op^Hv)=11.86607551574707
y= <class 'jaxlib.xla_extension.ArrayImpl'>
19.4 µs ± 5.65 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
50.5 µs ± 17.8 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
centered False 5
Dot test passed, v^H(Opu)=-11.492124557495117 - u^H(Op^Hv)=-11.4921236038208
y= <class 'jaxlib.xla_extension.ArrayImpl'>
14.7 µs ± 2.67 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
32.8 µs ± 2.02 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
centered True 5
Dot test passed, v^H(Opu)=8.562893867492676 - u^H(Op^Hv)=8.562893867492676
y= <class 'jaxlib.xla_extension.ArrayImpl'>
26.7 µs ± 3.55 µs per loop (mean

## Laplacian

In [33]:
ny, nx = 21, 11
   
for kind in ('forward', 'centered', 'backward'):
    for edge in (False, True):
        print(kind, edge, order)
        D1op = Laplacian((ny, nx), edge=edge, kind=kind, dtype='float32')
        x = np.ones(ny*nx)
        y = D1op @ x

        D1opjax = JaxOperator(Laplacian((ny, nx), edge=edge, kind=kind, dtype='float32'))
        xjnp = jnp.array(x)
        yjnp = (D1opjax @ xjnp).block_until_ready()

        dottest(D1opjax, ny*nx, ny*nx, backend='jax', verb=True, atol=1e-3)
        print('y=', type(yjnp))

        %timeit -n 10 -r 2 D1op @ x
        %timeit -n 10 -r 2 (D1opjax @ xjnp).block_until_ready()

forward False 5
Dot test passed, v^H(Opu)=-33.99725341796875 - u^H(Op^Hv)=-33.99726867675781
y= <class 'jaxlib.xla_extension.ArrayImpl'>
48.3 µs ± 4.71 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
42.5 µs ± 21.2 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
forward True 5
Dot test passed, v^H(Opu)=67.1269760131836 - u^H(Op^Hv)=67.12693786621094
y= <class 'jaxlib.xla_extension.ArrayImpl'>
66.2 µs ± 19.6 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
24.8 µs ± 3.22 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
centered False 5
Dot test passed, v^H(Opu)=-136.33056640625 - u^H(Op^Hv)=-136.33055114746094
y= <class 'jaxlib.xla_extension.ArrayImpl'>
67.8 µs ± 15.3 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
23.1 µs ± 2.91 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
centered True 5
Dot test passed, v^H(Opu)=-72.66870880126953 - u^H(Op^Hv)=-72.66866302490234
y= <class 'jaxlib.xla_extension.ArrayImpl'>
68.4 µs ± 16.3 µs per loop (me

## Gradient

In [34]:
ny, nx = 21, 11
   
for kind in ('forward', 'centered', 'backward'):
    for edge in (False, True):
        print(kind, edge, order)
        D1op = Gradient((ny, nx), edge=edge, kind=kind, dtype='float32')
        x = np.ones(ny*nx)
        y = D1op @ x

        D1opjax = JaxOperator(Gradient((ny, nx), edge=edge, kind=kind, dtype='float32'))
        xjnp = jnp.array(x)
        yjnp = (D1opjax @ xjnp).block_until_ready()

        dottest(D1opjax, 2*ny*nx, ny*nx, backend='jax', verb=True, atol=1e-3)
        print('y=', type(yjnp))

        %timeit -n 10 -r 2 D1op @ x
        %timeit -n 10 -r 2 (D1opjax @ xjnp).block_until_ready()

forward False 5
Dot test passed, v^H(Opu)=12.33971118927002 - u^H(Op^Hv)=12.339707374572754
y= <class 'jaxlib.xla_extension.ArrayImpl'>
36 µs ± 4.04 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
35 µs ± 13.5 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
forward True 5
Dot test passed, v^H(Opu)=30.52020835876465 - u^H(Op^Hv)=30.520204544067383
y= <class 'jaxlib.xla_extension.ArrayImpl'>
44.7 µs ± 18.3 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
42.3 µs ± 6.29 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
centered False 5
Dot test passed, v^H(Opu)=11.551376342773438 - u^H(Op^Hv)=11.551379203796387
y= <class 'jaxlib.xla_extension.ArrayImpl'>
40.8 µs ± 4.45 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
34.1 µs ± 12.5 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
centered True 5
Dot test passed, v^H(Opu)=27.0750675201416 - u^H(Op^Hv)=27.07506561279297
y= <class 'jaxlib.xla_extension.ArrayImpl'>
50.7 µs ± 14.6 µs per loop (mean ± 

## Causal integration

In [35]:
# Causal integration on a long (1e6-sample) vector to compare throughput;
# the same operator instance is wrapped directly in JaxOperator
nx = 1000000

Cop = CausalIntegration(nx, dtype='float32')

x = np.ones(nx, dtype='float32')
y = Cop @ x

Copjax = JaxOperator(Cop)

xjnp = jnp.ones(nx)
yjnp = (Copjax @ xjnp).block_until_ready()
print('y=', type(yjnp))

%timeit -n 10 -r 2 Cop @ x
%timeit -n 10 -r 2 (Copjax @ xjnp).block_until_ready()

y= <class 'jaxlib.xla_extension.ArrayImpl'>
3.72 ms ± 58.7 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
2.89 ms ± 144 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)


## VStack

In [36]:
# NOTE(review): D2hop/D2vop are not used by the stacking examples in this
# cell or in the HStack/BlockDiag cells; remove them if no later cell needs them
D2hop = SecondDerivative(dims=[11,21], axis=1, dtype='float64')
D2vop = SecondDerivative(dims=[11,21], axis=0, dtype='float64')

# JAX first-derivative operators along each axis of an (11, 21) model; they
# are also reused by the HStack and BlockDiag cells
D1hopjax = JaxOperator(FirstDerivative(dims=[11,21], axis=1, dtype='float32'))
D1vopjax = JaxOperator(FirstDerivative(dims=[11,21], axis=0, dtype='float32'))
    
# vertical stack: both operators act on the same 11*21 model
Vstackop = VStack([D1hopjax,D1vopjax])
xjnp = jnp.ones(11*21)
yjnp = (Vstackop @ xjnp).block_until_ready()

dottest(Vstackop, backend='jax', verb=True, atol=1e-3)
print('y=', type(yjnp))

Dot test passed, v^H(Opu)=-9.191097259521484 - u^H(Op^Hv)=-9.191096305847168
y= <class 'jaxlib.xla_extension.ArrayImpl'>


## HStack

In [37]:
# horizontal stack of the two JAX derivative operators: the model is the
# concatenation of two 11*21 vectors
Hstackop = HStack([D1hopjax,D1vopjax])
xjnp = jnp.ones(2*11*21)
yjnp = (Hstackop @ xjnp).block_until_ready()

dottest(Hstackop, backend='jax', verb=True, atol=1e-3)
print('y=', type(yjnp))

Dot test passed, v^H(Opu)=6.765942573547363 - u^H(Op^Hv)=6.765944957733154
y= <class 'jaxlib.xla_extension.ArrayImpl'>


## BlockDiag

In [38]:
# block-diagonal combination of the two JAX derivative operators: each acts
# on its own 11*21 half of the model
Blockop = BlockDiag([D1hopjax,D1vopjax])
xjnp = jnp.ones(2*11*21)
yjnp = (Blockop @ xjnp).block_until_ready()

dottest(Blockop, backend='jax', verb=True, atol=1e-3)
print('y=', type(yjnp))

Dot test passed, v^H(Opu)=8.74829387664795 - u^H(Op^Hv)=8.748291969299316
y= <class 'jaxlib.xla_extension.ArrayImpl'>


## FFT

In [39]:
# Real-input FFT (real=True) of a length-nx signal: dot test and timing for
# the NumPy operator, then for the same operator wrapped in JaxOperator
nx = 100000

Fop = FFT(nx, real=True, dtype='float32')

x = np.ones(nx, dtype='float32')
y = Fop @ x
xadj = Fop.H @ y
dottest(Fop, complexflag=1, verb=True, atol=1e-3)

Fopjax = JaxOperator(Fop)

xjnp = jnp.ones(nx, dtype='float32')
yjnp = (Fopjax @ xjnp).block_until_ready()
xadjjnp = (Fopjax.H @ yjnp).block_until_ready()
print('y=', type(yjnp))

# looser tolerance than the NumPy dot test above (1e-2 vs 1e-3)
dottest(Fopjax, backend='jax', complexflag=1, verb=True, atol=1e-2)

%timeit -n 10 -r 2 Fop @ x
%timeit -n 10 -r 2 (Fopjax @ xjnp).block_until_ready()

Dot test passed, v^H(Opu)=38.648006439208984 - u^H(Op^Hv)=38.648101806640625
y= <class 'jaxlib.xla_extension.ArrayImpl'>




Dot test passed, v^H(Opu)=136.31239318847656 - u^H(Op^Hv)=136.31370544433594
952 µs ± 21.7 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
1.3 ms ± 44.7 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)


## FFT2D

In [40]:
# 2D FFT of an (nx, nt) model with nffts four times the model size per axis
nx, nt = 200, 100

Fop = FFT2D((nx, nt), nffts=(4*nx, 4*nt), dtype='float32')
x = np.ones(nx*nt, dtype='float32')
y = Fop @ x
dottest(Fop, complexflag=1, verb=True, atol=1e-3)

Fopjax = JaxOperator(Fop)

xjnp = jnp.ones(nx*nt, dtype='float32')
yjnp = (Fopjax @ xjnp).block_until_ready()
print('y=', type(yjnp))

dottest(Fopjax, backend='jax', complexflag=1, verb=True, atol=1e-2)

%timeit -n 10 -r 2 Fop @ x
%timeit -n 10 -r 2 (Fopjax @ xjnp).block_until_ready()

Dot test passed, v^H(Opu)=133.7815704345703 - u^H(Op^Hv)=133.78172302246094
y= <class 'jaxlib.xla_extension.ArrayImpl'>




Dot test passed, v^H(Opu)=46.66827392578125 - u^H(Op^Hv)=46.669124603271484
3.42 ms ± 274 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
1.74 ms ± 66.9 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)


## FFTND

In [41]:
# N-dimensional FFT of a (ny, nx, nt) model using the numpy engine; the same
# operator instance is wrapped in JaxOperator (no dot test in this cell)
ny, nx, nt = 10, 200, 100

Fop = FFTND((ny, nx, nt), engine='numpy', dtype='float32')

x = np.ones(ny*nx*nt, dtype='float32')
y = Fop @ x

Fopjax = JaxOperator(Fop)
xjnp = jnp.ones(ny*nx*nt, dtype='float32')
yjnp = (Fopjax @ xjnp).block_until_ready()
print('y=', type(yjnp))

%timeit -n 10 -r 2 Fop @ x
%timeit -n 10 -r 2 (Fopjax @ xjnp).block_until_ready()

y= <class 'jaxlib.xla_extension.ArrayImpl'>
4 ms ± 18.5 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
1.46 ms ± 54 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)




## Convolve1D

In [42]:
# 1D direct convolution of a spike (at index 3) with the 6-tap filter
# h = [-3, ..., 2], applied with offset=4
nx = 10
offset = 4

h = np.arange(-3, 3)
print(h)

Cop = Convolve1D(nx, h=h, offset=offset, 
                 dtype='float32', 
                 method='direct')
x  = np.zeros(nx)
x[3] = 1
y = Cop @ x

hjnp = jnp.array(h).astype(jnp.float32)
Copjax = JaxOperator(Convolve1D(nx, h=hjnp, offset=offset, 
                     dtype='float32', 
                     method='direct'))

xjnp = jnp.array(x).astype(jnp.float32)
yjnp = (Copjax @ xjnp).block_until_ready()
print('y=', type(yjnp))

# NOTE(review): dot test and timings for this case are disabled; the reason
# is not stated here — confirm before re-enabling
#dottest(Copjax, backend='jax', verb=True, atol=1e-3)

#%timeit -n 10 -r 2 Cop @ x
#%timeit -n 10 -r 2 (Copjax @ xjnp).block_until_ready()

[-3 -2 -1  0  1  2]
y= <class 'jaxlib.xla_extension.ArrayImpl'>


In [43]:
# Convolve1D along axis 0 of a 2D (nt, nx) array with a 3-sample box filter.
# Only the NumPy operator is applied and timed: the JaxOperator is constructed,
# but applying it, its dot test and its timing are commented out (see banner).
## JAX BACKEND CURRENTLY NOT WORKING ##
nt = 301
nx = 20
dt = 0.004
t  = np.arange(nt)*dt

h = np.ones(3)
hcenter=1

Cop = Convolve1D(dims=[nt,nx], h=h, offset=hcenter, axis=0, dtype='float32')
x  = np.zeros((nt, nx))
# one horizontal line of ones at the middle time sample
x[int(nt/2),:] = 1
y = Cop @ x

hjnp = jnp.array(h).astype(jnp.float32)
Copjax = JaxOperator(Convolve1D(dims=[nt,nx], h=hjnp, offset=hcenter, axis=0, dtype='float32'))

xjnp = jnp.array(x).astype(jnp.float32)
#yjnp = (Copjax @ xjnp).block_until_ready()
#print('y=', type(yjnp))

#dottest(Copjax, backend='jax', verb=True, atol=1e-3)

%timeit -n 10 -r 2 Cop @ x
#%timeit -n 10 -r 2 (Copjax @ xjnp).block_until_ready()

105 µs ± 9.69 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)


## Convolve2D

In [44]:
nt = 51
nx = 81
dt = 0.004
t  = np.arange(nt)*dt

nh = [11,5]
h  = np.ones((nh[0], nh[1]))

Cop = Convolve2D(dims=[nt,nx], h=h, offset=[int(nh[0])/2,int(nh[1])/2], 
                 dtype='float32')
x  = np.zeros((nt,nx))
x[int(nt/2),int(nx/2)] = 1
y    = Cop*x.flatten()

hjnp = jnp.array(h).astype(jnp.float32)
Copjax = JaxOperator(Convolve2D(dims=[nt,nx], h=hjnp, 
                                offset=[int(nh[0])/2,int(nh[1])/2], dtype='float32'))

xjnp = jnp.array(x).astype(jnp.float32)
yjnp = (Copjax @ xjnp).block_until_ready()
print('y=', type(yjnp))

dottest(Copjax, backend='jax', verb=True, atol=1e-2)

%timeit -n 10 -r 2 Cop @ x
%timeit -n 10 -r 2 (Copjax @ xjnp).block_until_ready()

y= <class 'jaxlib.xla_extension.ArrayImpl'>
Dot test passed, v^H(Opu)=533.0953369140625 - u^H(Op^Hv)=533.09326171875
182 µs ± 46.8 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
158 µs ± 17.2 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)


In [45]:
nt = 51
nx = 81
ny = 11
dt = 0.004
t  = np.arange(nt)*dt
x  = np.zeros((nt, nx, ny))
x[int(nt/2), int(nx/2), int(ny/2)] = 1

nh = [11,5]
h  = np.ones((nh[0], nh[1]))

Cop = Convolve2D(dims=[nt,nx,ny], h=h, offset=[int(nh[0])/2,int(nh[1])/2], 
                 axes=(0,1), dtype='float32')

hjnp = jnp.array(h).astype(jnp.float32)
Copjax = JaxOperator(Convolve2D(dims=[nt,nx,ny], h=hjnp, offset=[int(nh[0])/2,int(nh[1])/2], 
                                axes=(0,1), dtype='float32'))

xjnp = jnp.array(x).astype(jnp.float32)
yjnp = (Copjax @ xjnp).block_until_ready()
print('y=', type(yjnp))

dottest(Copjax, backend='jax', verb=True, atol=1e-2)

%timeit -n 10 -r 2 Cop @ x
%timeit -n 10 -r 2 (Copjax @ xjnp).block_until_ready()

y= <class 'jaxlib.xla_extension.ArrayImpl'>
Dot test passed, v^H(Opu)=-1388.2647705078125 - u^H(Op^Hv)=-1388.2679443359375
802 µs ± 95.4 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
1.36 ms ± 10.5 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)


## ConvolveND

In [46]:
# ConvolveND over all three axes of a (ny, nx, nz) cube with one spike,
# using a (3, 5, 3) box filter; NumPy operator vs JaxOperator.
ny, nx, nz = 20, 40, 30
x = np.zeros((ny, nx, nz))
x[ny//3, nx//2, nz//4] = 1
h = np.ones((3, 5, 3))
# equals each filter half-length (3//2, 5//2, 3//2); presumably centres the filter
offset = [1, 2, 1]

Cop = ConvolveND(dims=[ny, nx, nz], h=h, offset=offset, 
                 axes=[0,1,2], dtype='float32')
y    = Cop @ x

hjnp = jnp.array(h).astype(jnp.float32)
Copjax = JaxOperator(ConvolveND(dims=[ny, nx, nz], h=hjnp, offset=offset, 
                                axes=[0,1,2], dtype='float32'))
xjnp = jnp.array(x).astype(jnp.float32)
# block_until_ready() waits for JAX's asynchronous dispatch to finish
yjnp = (Copjax @ xjnp).block_until_ready()
print('y=', type(yjnp))

dottest(Copjax, backend='jax', verb=True, atol=1e-2)

%timeit -n 10 -r 2 Cop @ x
%timeit -n 10 -r 2 (Copjax @ xjnp).block_until_ready()

y= <class 'jaxlib.xla_extension.ArrayImpl'>
Dot test passed, v^H(Opu)=617.8408813476562 - u^H(Op^Hv)=617.8402099609375
855 µs ± 11 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
543 µs ± 19.4 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)


## NonStationaryConvolve1D

In [47]:
# NonStationaryConvolve1D with two Ricker wavelets (f0=20 and f0=50) applied to
# a spike train; dot-test and time the NumPy and JAX operators.
nt = 601
dt = 0.004
t = np.arange(nt)*dt

wav1, _, wav1c = ricker(t[:51], f0=20)
wav2 = ricker(t[:51], f0=50)[0]
# rescale wav2 so its sum of absolute values matches wav1's
wav2 = wav2 * (np.sum(np.abs(wav1)) / np.sum(np.abs(wav2)))
wavs = np.stack([wav1, wav2])

# ih=(201, 401): presumably the sample indices at which the two filters are defined
Cop = NonStationaryConvolve1D(dims=nt, hs=wavs, ih=(201, 401))
x = np.zeros(nt)
# spike train: ones every 64 samples, from 64 up to (excluding) nt-64
for ix in range(64, nt-64, 64):
    x[ix] = 1.    
y = Cop @ x
dottest(Cop, verb=True, atol=1e-3)

# wavsjnp is used as the model vector by the NonStationaryFilters1D cell below.
# NOTE(review): the JAX operator here is built from the NumPy ``wavs``, not
# ``wavsjnp`` — confirm this is intended.
wavsjnp = jnp.array(wavs).astype(jnp.float32)
Copjax = JaxOperator(NonStationaryConvolve1D(dims=nt, hs=wavs, ih=(201, 401), dtype='float32'))
xjnp = jnp.array(x).astype(jnp.float32)
# block_until_ready() waits for JAX's asynchronous dispatch to finish
yjnp = (Copjax @ xjnp).block_until_ready()
print('y=', type(yjnp))
dottest(Copjax, backend='jax', verb=True, atol=1e-3)

%timeit -n 2 -r 2 Cop @ x
%timeit -n 2 -r 2 (Copjax @ xjnp).block_until_ready()

Dot test passed, v^H(Opu)=-37.6285863677083 - u^H(Op^Hv)=-37.62858636770827
y= <class 'jaxlib.xla_extension.ArrayImpl'>
Dot test passed, v^H(Opu)=-114.16763305664062 - u^H(Op^Hv)=-114.16763305664062
5.09 ms ± 512 µs per loop (mean ± std. dev. of 2 runs, 2 loops each)
The slowest run took 6.98 times longer than the fastest. This could mean that an intermediate result is being cached.
223 µs ± 167 µs per loop (mean ± std. dev. of 2 runs, 2 loops each)


In [48]:
# NonStationaryFilters1D: the spike train ``x`` from the previous cell is fixed
# as the operator input and the operator acts on the stacked filters ``wavs``
# (``wavsjnp`` for the JAX version), both also from the previous cell.
C1op = NonStationaryFilters1D(inp=x, hsize=wavs.shape[1],  ih=(201, 401))
y1 = C1op @ wavs
dottest(C1op, verb=True, atol=1e-3)

C1opjax = JaxOperator(NonStationaryFilters1D(inp=x, hsize=wavs.shape[1],  ih=(201, 401), dtype='float32'))
# block_until_ready() waits for JAX's asynchronous dispatch to finish
yjnp = (C1opjax @ wavsjnp).block_until_ready()
print('y=', type(yjnp))
dottest(C1opjax, backend='jax', verb=True, atol=1e-3)

%timeit -n 2 -r 2 C1op @ wavs
%timeit -n 2 -r 2 (C1opjax @ wavsjnp).block_until_ready()

Dot test passed, v^H(Opu)=-7.2153320796083005 - u^H(Op^Hv)=-7.215332079608288
y= <class 'jaxlib.xla_extension.ArrayImpl'>
Dot test passed, v^H(Opu)=19.079679489135742 - u^H(Op^Hv)=19.079675674438477
4.56 ms ± 97.6 µs per loop (mean ± std. dev. of 2 runs, 2 loops each)
The slowest run took 4.26 times longer than the fastest. This could mean that an intermediate result is being cached.
138 µs ± 85.8 µs per loop (mean ± std. dev. of 2 runs, 2 loops each)


## Interp

In [49]:
n = int(1e2)
iava = np.arange(0, n, 10) + 0.5

for kind in ('nearest', 'linear', 'sinc'):
    Iop = pylops.signalprocessing.Interp(n, iava, kind='nearest')[0]
    x = np.ones(n)
    y = Iop @ x

    Iopjax = JaxOperator(pylops.signalprocessing.Interp(n, iava, kind='nearest', dtype='float32')[0])
    xjnp = jnp.array(x).astype(jnp.float32)
    yjnp = (Iopjax @ xjnp).block_until_ready()
    print('y=', type(yjnp))

    dottest(Iopjax, backend='jax', verb=True, atol=1e-2)

    %timeit -n 10 -r 2 Iop @ x
    %timeit -n 10 -r 2 (Iopjax @ xjnp).block_until_ready()

y= <class 'jaxlib.xla_extension.ArrayImpl'>
Dot test passed, v^H(Opu)=-0.9163461923599243 - u^H(Op^Hv)=-0.9163461923599243
9.54 µs ± 2.19 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
37.2 µs ± 4.09 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
y= <class 'jaxlib.xla_extension.ArrayImpl'>
Dot test passed, v^H(Opu)=-0.8063273429870605 - u^H(Op^Hv)=-0.8063273429870605
9.5 µs ± 1.92 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
36.4 µs ± 3.04 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
y= <class 'jaxlib.xla_extension.ArrayImpl'>
Dot test passed, v^H(Opu)=-5.931110858917236 - u^H(Op^Hv)=-5.931110858917236
10.4 µs ± 3.03 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
65.3 µs ± 31.6 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)


In [50]:
n = int(1e2)
iava = np.arange(0, n, 10) + 0.5

for kind in ('nearest', 'linear', 'sinc'):
    Iop = pylops.signalprocessing.Interp((10, n), iava, kind='nearest', axis=-1)[0]
    x = np.ones(n*10)
    y = Iop @ x

    Iopjax = JaxOperator(pylops.signalprocessing.Interp((10, n), iava, kind='nearest', axis=-1, dtype='float32')[0])
    xjnp = jnp.array(x).astype(jnp.float32)
    yjnp = (Iopjax @ xjnp).block_until_ready()
    print('y=', type(yjnp))

    dottest(Iopjax, backend='jax', verb=True, atol=1e-2)

    %timeit -n 10 -r 2 Iop @ x
    %timeit -n 10 -r 2 (Iopjax @ xjnp).block_until_ready()

y= <class 'jaxlib.xla_extension.ArrayImpl'>
Dot test passed, v^H(Opu)=2.492629289627075 - u^H(Op^Hv)=2.492629289627075
9.41 µs ± 1.98 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
25.3 µs ± 3.7 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
y= <class 'jaxlib.xla_extension.ArrayImpl'>
Dot test passed, v^H(Opu)=-5.612744331359863 - u^H(Op^Hv)=-5.612744331359863
8.92 µs ± 1.86 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
22 µs ± 2.01 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
y= <class 'jaxlib.xla_extension.ArrayImpl'>
Dot test passed, v^H(Opu)=0.781684160232544 - u^H(Op^Hv)=0.781684160232544
9.06 µs ± 1.8 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
23.1 µs ± 2.38 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)


## Shift

In [51]:
# Shift of a length-n signal by a non-integer 1.5 samples (sampling=1.).
# The NumPy operator is float64, the JAX one float32.
n = 100

shift = 1.5
Sop = pylops.signalprocessing.Shift(n, shift, sampling=1., real=True, dtype=np.float64)
x = np.ones(n)
y = Sop @ x

Sopjax = JaxOperator(pylops.signalprocessing.Shift(n, shift, sampling=1., real=True, dtype='float32'))
xjnp = jnp.array(x).astype(jnp.float32)
# block_until_ready() waits for JAX's asynchronous dispatch to finish
yjnp = (Sopjax @ xjnp).block_until_ready()
print('y=', type(yjnp))

dottest(Sopjax, backend='jax', verb=True, atol=1e-2)

%timeit -n 10 -r 2 Sop @ x
%timeit -n 10 -r 2 (Sopjax @ xjnp).block_until_ready()

y= <class 'jaxlib.xla_extension.ArrayImpl'>
Dot test passed, v^H(Opu)=-21.916141510009766 - u^H(Op^Hv)=-21.916139602661133
50.5 µs ± 5.97 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
47.9 µs ± 4.83 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)


In [52]:
# Shift by 1.5 samples on a 2D (10, n) array, along the operator's default axis
# (presumably the last one, of length n — confirm against the pylops docs).
n = 100

shift = 1.5
Sop = pylops.signalprocessing.Shift((10, n), shift, sampling=1., real=True, dtype=np.float64)
x = np.ones(10*n)
y = Sop @ x

Sopjax = JaxOperator(pylops.signalprocessing.Shift((10, n), shift, sampling=1., real=True, dtype='float32'))
xjnp = jnp.array(x).astype(jnp.float32)
# block_until_ready() waits for JAX's asynchronous dispatch to finish
yjnp = (Sopjax @ xjnp).block_until_ready()
print('y=', type(yjnp))

dottest(Sopjax, backend='jax', verb=True, atol=1e-2)

%timeit -n 10 -r 2 Sop @ x
%timeit -n 10 -r 2 (Sopjax @ xjnp).block_until_ready()

y= <class 'jaxlib.xla_extension.ArrayImpl'>
Dot test passed, v^H(Opu)=64.58868408203125 - u^H(Op^Hv)=64.5887451171875
83.5 µs ± 4.74 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
45.1 µs ± 5.63 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)


## Sliding1D

In [53]:
# Sliding1D with an Identity per window and a cosine taper: sliding1d_design
# derives the model size ``dim`` for a data length of 49, windows of 26
# samples and an overlap of 3.
from pylops.signalprocessing.sliding1d import sliding1d_design

dimd = 49
nwin = 26
nover = 3
# 4th argument (nwin): presumably the per-window model size of the Op — confirm
nwins, dim, mwin_inends, dwin_inends = sliding1d_design(dimd, nwin, nover, nwin)

Op = Identity(nwin, nwin)
Sop = Sliding1D(Op.H, dim, dimd, nwin, nover, tapertype='cosine')
x = np.ones(dim)
y = Sop * x

# JAX version: the per-window operator is itself a JaxOperator
Opjax = JaxOperator(Identity(nwin, nwin, dtype='float32'))
Sopjax = JaxOperator(Sliding1D(Opjax.H, dim, dimd, nwin, nover, tapertype='cosine'))
xjnp = jnp.array(x).astype(jnp.float32)
# block_until_ready() waits for JAX's asynchronous dispatch to finish
yjnp = (Sopjax @ xjnp).block_until_ready()
print('y=', type(yjnp))

dottest(Sopjax, backend='jax', verb=True, atol=1e-2)

%timeit -n 10 -r 2 Sop @ x
%timeit -n 10 -r 2 (Sopjax @ xjnp).block_until_ready()



y= <class 'jaxlib.xla_extension.ArrayImpl'>
Dot test passed, v^H(Opu)=-2.105905771255493 - u^H(Op^Hv)=-2.105905771255493
51.4 µs ± 6.6 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
25.5 µs ± 3.86 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)


## Sliding2D

In [54]:
# Sliding2D over a (100, 50) data array with 26-sample windows (overlap 4) and
# no taper; each window's operator is the adjoint of a (256, 256)-point FFT2D.
from pylops.signalprocessing.sliding2d import sliding2d_design

dimsd = (100, 50)
nwin = 26
nover = 4

# (256, 256) matches the FFT2D nffts below; presumably the per-window model shape
nwins, dims, mwin_inends, dwin_inends = sliding2d_design(dimsd, nwin, nover, (256, 256))
Op = FFT2D((nwin, dimsd[1]), nffts=(256, 256))
Sop = Sliding2D(Op.H, dims, dimsd, nwin, nover, tapertype=None)
x = np.ones(Sop.dims)
y = Sop @ x

# JAX version: the per-window FFT2D is itself a JaxOperator
Opjax = JaxOperator(FFT2D((nwin, dimsd[1]), nffts=(256, 256), dtype='float32'))
Sopjax = JaxOperator(Sliding2D(Opjax.H, dims, dimsd, nwin, nover, tapertype=None))
xjnp = jnp.array(x).astype(jnp.float32)
# block_until_ready() waits for JAX's asynchronous dispatch to finish
yjnp = (Sopjax @ xjnp).block_until_ready()
print('y=', type(yjnp))

dottest(Sopjax, backend='jax', verb=True, atol=1e-2)

%timeit -n 10 -r 2 Sop @ x
%timeit -n 10 -r 2 (Sopjax @ xjnp).block_until_ready()



y= <class 'jaxlib.xla_extension.ArrayImpl'>
Dot test passed, v^H(Opu)=19.226972579956055 - u^H(Op^Hv)=19.226343154907227
6.14 ms ± 961 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
1.51 ms ± 22.3 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)


## Fredholm1

In [55]:
# Fredholm1 with a (nt, nh1, nh2) kernel of ones and nz model columns.
# The JAX operator is built from the same NumPy kernel ``h``.
nt, nh1, nh2, nz = 100, 10, 7, 5
h = np.ones((nt, nh1, nh2))

Fop = pylops.signalprocessing.Fredholm1(h, nz)
x = np.ones(Fop.dims)
y = Fop @ x

Fopjax = JaxOperator(pylops.signalprocessing.Fredholm1(h, nz))
xjnp = jnp.array(x).astype(jnp.float32)
# block_until_ready() waits for JAX's asynchronous dispatch to finish
yjnp = (Fopjax @ xjnp).block_until_ready()
print('y=', type(yjnp))

dottest(Fopjax, backend='jax', verb=True, atol=1e-2)

%timeit -n 10 -r 2 Fop @ x
%timeit -n 10 -r 2 (Fopjax @ xjnp).block_until_ready()

y= <class 'jaxlib.xla_extension.ArrayImpl'>
Dot test passed, v^H(Opu)=55.578914642333984 - u^H(Op^Hv)=55.579063415527344
46.5 µs ± 3.34 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
160 µs ± 11.8 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)


## Poststack

In [56]:
nt0 = 301
dt0 = 0.004
t0 = np.arange(nt0) * dt0

ntwav = 41
wav, twav, wavc = ricker(t0[: ntwav // 2 + 1], 20)

for explicit in (True, False):
    PPop = pylops.avo.poststack.PoststackLinearModelling(wav / 2, nt0=nt0, explicit=True)
    x = np.zeros(nt0)
    x[nt0//4:] = 1.
    y = PPop @ x

    PPopjax = JaxOperator(pylops.avo.poststack.PoststackLinearModelling(jnp.array(wav).astype(jnp.float32) / 2, 
                                                                        nt0=nt0, explicit=True))
    xjnp = jnp.array(x).astype(jnp.float32)
    yjnp = (PPopjax @ xjnp).block_until_ready()
    print('y=', type(yjnp))

    dottest(PPopjax, backend='jax', verb=True, atol=1e-2)

    %timeit -n 10 -r 2 PPop @ x
    %timeit -n 10 -r 2 (PPopjax @ xjnp).block_until_ready()



y= <class 'jaxlib.xla_extension.ArrayImpl'>
Dot test passed, v^H(Opu)=-0.3276890814304352 - u^H(Op^Hv)=-0.32768774032592773
17.8 µs ± 4.55 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
46.4 µs ± 4.81 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
y= <class 'jaxlib.xla_extension.ArrayImpl'>
Dot test passed, v^H(Opu)=-7.148036479949951 - u^H(Op^Hv)=-7.14804220199585
The slowest run took 5.04 times longer than the fastest. This could mean that an intermediate result is being cached.
105 µs ± 70.5 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
47.2 µs ± 5.05 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)


## Prestack

In [57]:
# AVOLinearModelling ('akirich') for angles 0..25 deg, with a constant vsvp and
# a per-sample vsvp array (``nt0`` comes from the Poststack cell). The JAX
# operator receives float32 jax angles but the same NumPy ``vsvp``.
theta = np.arange(0, 30., 5.).astype(np.float32)

for vsvp in (0.5, np.ones(nt0)):
    PPop = pylops.avo.avo.AVOLinearModelling(
        theta, vsvp=vsvp, nt0=nt0, linearization="akirich", dtype=np.float64
    )
    # model of 3*nt0 entries: presumably three elastic parameters per time sample
    x = np.zeros(3*nt0)
    y = PPop @ x

    PPopjax = JaxOperator(pylops.avo.avo.AVOLinearModelling(jnp.array(theta).astype(jnp.float32), 
                                                            vsvp=vsvp, nt0=nt0, 
                                                            linearization="akirich", dtype=np.float32))
    xjnp = jnp.array(x).astype(jnp.float32)
    # block_until_ready() waits for JAX's asynchronous dispatch to finish
    yjnp = (PPopjax @ xjnp).block_until_ready()
    print('y=', type(yjnp))

    dottest(PPopjax, backend='jax', verb=True, atol=1e-2)

    %timeit -n 10 -r 2 PPop @ x
    %timeit -n 10 -r 2 (PPopjax @ xjnp).block_until_ready()

y= <class 'jaxlib.xla_extension.ArrayImpl'>
Dot test passed, v^H(Opu)=6.903024673461914 - u^H(Op^Hv)=6.903029441833496
42.1 µs ± 4.14 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
28.8 µs ± 3.96 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
y= <class 'jaxlib.xla_extension.ArrayImpl'>
Dot test passed, v^H(Opu)=-46.652462005615234 - u^H(Op^Hv)=-46.652435302734375
38 µs ± 2.33 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
25.3 µs ± 1.81 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)


In [58]:
theta = np.arange(0, 30., 5.).astype(np.float32)

for explicit in (True,): # False):
    for vsvp in (0.5, np.ones(nt0)):
        print(explicit, vsvp)
        PPop = pylops.avo.prestack.PrestackLinearModelling(
            wav, theta, vsvp=vsvp, nt0=nt0, linearization="akirich", explicit=explicit,
        )
        x = np.zeros(3*nt0)
        y = PPop @ x

        PPopjax = JaxOperator(pylops.avo.prestack.PrestackLinearModelling(jnp.array(wav).astype(jnp.float32), 
                                                                          jnp.array(theta).astype(jnp.float32), 
                                                                          vsvp=vsvp, nt0=nt0, 
                                                                          linearization="akirich", 
                                                                          explicit=explicit))
        xjnp = jnp.array(x).astype(jnp.float32)
        yjnp = (PPopjax @ xjnp).block_until_ready()
        print('y=', type(yjnp))

        dottest(PPopjax, backend='jax', verb=True, atol=1e-2)

        %timeit -n 10 -r 2 PPop @ x
        %timeit -n 10 -r 2 (PPopjax @ xjnp).block_until_ready()

True 0.5
y= <class 'jaxlib.xla_extension.ArrayImpl'>
Dot test passed, v^H(Opu)=23.199710845947266 - u^H(Op^Hv)=23.199716567993164
1.16 ms ± 21.8 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
253 µs ± 53.6 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
True [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 

In [59]:
theta = np.arange(0, 30., 5.).astype(np.float32)

for vsvp in (0.5, np.ones(nt0)):
    print(vsvp)
    x = np.zeros((nt0, 3))
    PPop = pylops.avo.prestack.PrestackWaveletModelling(
        x, theta, vsvp=vsvp, nwav=nt//4, wavc=nt//8, linearization="akirich", 
    )
    wav = np.zeros(nt//4)
    y = PPop @ wav

    PPopjax = JaxOperator(pylops.avo.prestack.PrestackWaveletModelling(
        jnp.array(x).astype(jnp.float32), jnp.array(theta).astype(jnp.float32), 
        vsvp=vsvp, nwav=nt//4, wavc=nt//8, linearization="akirich"))
    wavjnp = jnp.array(wav).astype(jnp.float32)
    yjnp = (PPopjax @ wavjnp).block_until_ready()
    print('y=', type(yjnp))

    dottest(PPopjax, backend='jax', verb=True, atol=1e-2)

    %timeit -n 10 -r 2 PPop @ wav
    %timeit -n 10 -r 2 (PPopjax @ wavjnp).block_until_ready()

0.5
y= <class 'jaxlib.xla_extension.ArrayImpl'>
Dot test passed, v^H(Opu)=0.0 - u^H(Op^Hv)=0.0
44.2 µs ± 5.11 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
50.8 µs ± 3.6 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 

## PressureToVelocity

In [60]:
# PressureToVelocity on a 100 x 20 grid with (256, 256) FFTs and a 5-sample
# taper. NOTE(review): positional args presumably (nt, nr, dt, dr, rho, vel) —
# confirm against the pylops signature. The recorded output shows warnings
# raised at ``OBL = Kz / (rho * np.abs(F))`` during construction.
Pop = pylops.waveeqprocessing.PressureToVelocity(
    100, 20, 0.004, 1, 1000, 1000, nffts=(256, 256), ntaper=5, topressure=False)
x = np.ones(100*20)
y = Pop @ x

Popjax = JaxOperator(pylops.waveeqprocessing.PressureToVelocity(
    100, 20, 0.004, 1, 1000, 1000, nffts=(256, 256), ntaper=5, topressure=False))
xjnp = jnp.array(x).astype(jnp.float32)
# block_until_ready() waits for JAX's asynchronous dispatch to finish
yjnp = (Popjax @ xjnp).block_until_ready()
print('y=', type(yjnp))

dottest(Popjax, backend='jax', verb=True, atol=1e-2)

%timeit -n 10 -r 2 Pop @ x
%timeit -n 10 -r 2 (Popjax @ xjnp).block_until_ready()

  OBL = Kz / (rho * np.abs(F))
  OBL = Kz / (rho * np.abs(F))
  return getattr(self.aval, name).fun(self, *args, **kwargs)


y= <class 'jaxlib.xla_extension.ArrayImpl'>
Dot test passed, v^H(Opu)=(4.5178489926911425e-06+1.710300523427577e-08j) - u^H(Op^Hv)=(4.517860361374915e-06+1.7101630334082074e-08j)
2.06 ms ± 283 ns per loop (mean ± std. dev. of 2 runs, 10 loops each)
700 µs ± 10.2 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)


## UpDownComposition

In [61]:
# UpDownComposition2D (complex64) on a 100 x 20 grid. The model has 2*100*20
# entries — presumably two stacked wavefields; confirm against the pylops docs.
Pop = pylops.waveeqprocessing.UpDownComposition2D(
    100, 20, 0.004, 1, 1000, 1000, nffts=(256, 256), ntaper=5, dtype=np.complex64)
x = np.ones(2*100*20)
y = Pop @ x

Popjax = JaxOperator(pylops.waveeqprocessing.UpDownComposition2D(
    100, 20, 0.004, 1, 1000, 1000, nffts=(256, 256), ntaper=5, dtype=np.complex64))
xjnp = jnp.array(x).astype(jnp.float32)
# block_until_ready() waits for JAX's asynchronous dispatch to finish
yjnp = (Popjax @ xjnp).block_until_ready()
print('y=', type(yjnp))

dottest(Popjax, backend='jax', verb=True, atol=1e-2)

%timeit -n 10 -r 2 Pop @ x
%timeit -n 10 -r 2 (Popjax @ xjnp).block_until_ready()

y= <class 'jaxlib.xla_extension.ArrayImpl'>
Dot test passed, v^H(Opu)=(31.831830978393555-5.285550628286728e-07j) - u^H(Op^Hv)=(31.83177947998047+3.0859558819429367e-07j)
4.76 ms ± 102 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
1.6 ms ± 44.7 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)


In [62]:
# UpDownComposition3D (complex64, NumPy FFT engine) on a 30 x (20, 10) grid.
# Fewer timing loops (-n 4) because each application takes on the order of
# 100 ms here (see recorded output).
Pop = pylops.waveeqprocessing.UpDownComposition3D(
    30, (20, 10), 0.004, (1, 1), 1000, 1000, nffts=(128, 128, 128), ntaper=5, fftengine='numpy', dtype=np.complex64)
x = np.ones(2*30*20*10)
y = Pop @ x

Popjax = JaxOperator(pylops.waveeqprocessing.UpDownComposition3D(
    30, (20, 10), 0.004, (1, 1), 1000, 1000, nffts=(128, 128, 128), ntaper=5, fftengine='numpy', dtype=np.complex64))
xjnp = jnp.array(x).astype(jnp.float32)
# block_until_ready() waits for JAX's asynchronous dispatch to finish
yjnp = (Popjax @ xjnp).block_until_ready()
print('y=', type(yjnp))

dottest(Popjax, backend='jax', verb=True, atol=1e-2)

%timeit -n 4 -r 2 Pop @ x
%timeit -n 4 -r 2 (Popjax @ xjnp).block_until_ready()

  OBL = Kz / (rho * np.abs(F))
  OBL = Kz / (rho * np.abs(F))


y= <class 'jaxlib.xla_extension.ArrayImpl'>
Dot test passed, v^H(Opu)=(29.866849899291992-2.5170536446239566e-06j) - u^H(Op^Hv)=(29.86686897277832-1.1114821063529234e-06j)
382 ms ± 18.5 ms per loop (mean ± std. dev. of 2 runs, 4 loops each)
129 ms ± 3.42 ms per loop (mean ± std. dev. of 2 runs, 4 loops each)


## BlendingContinuous

In [63]:
# BlendingContinuous for ns=4 sources with shiftall=True and False.
# 0.002*np.ones(ns): presumably one firing time per source — confirm.
nt, nr, ns = 100, 2, 4

for shiftall in (True, False):
    Bop = pylops.waveeqprocessing.BlendingContinuous(nt, nr, ns, 0.004, 0.002*np.ones(ns), 
                                                     shiftall=shiftall, dtype=np.float32)
    x = np.ones(Bop.shape[1])
    y = Bop @ x

    Bopjax = JaxOperator(pylops.waveeqprocessing.BlendingContinuous(nt, nr, ns, 0.004, 0.002*np.ones(ns), 
                                                                    shiftall=shiftall, dtype=np.float32))
    xjnp = jnp.array(x).astype(jnp.float32)
    # block_until_ready() waits for JAX's asynchronous dispatch to finish
    yjnp = (Bopjax @ xjnp).block_until_ready()
    print('y=', type(yjnp))

    dottest(Bopjax, backend='jax', verb=True, atol=1e-2)

    %timeit -n 4 -r 2 Bop @ x
    %timeit -n 4 -r 2 (Bopjax @ xjnp).block_until_ready()

y= <class 'jaxlib.xla_extension.ArrayImpl'>
Dot test passed, v^H(Opu)=-41.90353775024414 - u^H(Op^Hv)=-41.90357208251953
231 µs ± 31.1 µs per loop (mean ± std. dev. of 2 runs, 4 loops each)
66.8 µs ± 13.4 µs per loop (mean ± std. dev. of 2 runs, 4 loops each)
y= <class 'jaxlib.xla_extension.ArrayImpl'>
Dot test passed, v^H(Opu)=-5.632740497589111 - u^H(Op^Hv)=-5.63276481628418
510 µs ± 54.3 µs per loop (mean ± std. dev. of 2 runs, 4 loops each)
169 µs ± 20 µs per loop (mean ± std. dev. of 2 runs, 4 loops each)


## BlendingGroup

In [64]:
# BlendingGroup for ns=8 sources in n_groups=4 groups of group_size=2
# (2 * 4 == ns). The (2, 4) times array presumably has shape
# (group_size, n_groups) — confirm against the pylops docs.
nt, nr, ns = 100, 2, 8
Bop = pylops.waveeqprocessing.BlendingGroup(nt, nr, ns, 0.004, 0.002*np.ones((2, 4)), group_size=2, n_groups=4, 
                                            dtype=np.float32)
x = np.ones(Bop.shape[1])
y = Bop @ x

Bopjax = JaxOperator(pylops.waveeqprocessing.BlendingGroup(nt, nr, ns, 0.004, 0.002*np.ones((2, 4)), group_size=2, n_groups=4, 
                                                           dtype=np.float32))
xjnp = jnp.array(x).astype(jnp.float32)
# block_until_ready() waits for JAX's asynchronous dispatch to finish
yjnp = (Bopjax @ xjnp).block_until_ready()
print('y=', type(yjnp))

dottest(Bopjax, backend='jax', verb=True, atol=1e-2)

%timeit -n 4 -r 2 Bop @ x
%timeit -n 4 -r 2 (Bopjax @ xjnp).block_until_ready()

y= <class 'jaxlib.xla_extension.ArrayImpl'>
Dot test passed, v^H(Opu)=-28.674692153930664 - u^H(Op^Hv)=-28.6746826171875
345 µs ± 62 µs per loop (mean ± std. dev. of 2 runs, 4 loops each)
69.4 µs ± 14 µs per loop (mean ± std. dev. of 2 runs, 4 loops each)


## BlendingHalf

In [65]:
# BlendingHalf with the same geometry as the BlendingGroup cell: ns=8 sources,
# n_groups=4 groups of group_size=2, and a (2, 4) array of times.
nt, nr, ns = 100, 2, 8
Bop = pylops.waveeqprocessing.BlendingHalf(nt, nr, ns, 0.004, 0.002*np.ones((2, 4)), group_size=2, n_groups=4, 
                                            dtype=np.float32)
x = np.ones(Bop.shape[1])
y = Bop @ x

Bopjax = JaxOperator(pylops.waveeqprocessing.BlendingHalf(nt, nr, ns, 0.004, 0.002*np.ones((2, 4)), group_size=2, n_groups=4, 
                                                          dtype=np.float32))

xjnp = jnp.array(x).astype(jnp.float32)
# block_until_ready() waits for JAX's asynchronous dispatch to finish
yjnp = (Bopjax @ xjnp).block_until_ready()
print('y=', type(yjnp))

dottest(Bopjax, backend='jax', verb=True, atol=1e-2)

%timeit -n 4 -r 2 Bop @ x
%timeit -n 4 -r 2 (Bopjax @ xjnp).block_until_ready()

y= <class 'jaxlib.xla_extension.ArrayImpl'>
Dot test passed, v^H(Opu)=2.5328550338745117 - u^H(Op^Hv)=2.5328428745269775
295 µs ± 20.7 µs per loop (mean ± std. dev. of 2 runs, 4 loops each)
71.5 µs ± 12.2 µs per loop (mean ± std. dev. of 2 runs, 4 loops each)


## MDC

In [66]:
# MDC with a constant complex64 (20, 5, 4) kernel (all entries 1+1j), tested with
# usematmul=True and False; the JAX version receives the kernel as a jax array.
for usematmul in (True, False):
    # Create operator
    nt = 51
    MDCop = pylops.waveeqprocessing.MDC(
        (np.ones((20, 5, 4)) + 1j* np.ones((20, 5, 4))).astype(np.complex64),
        nt=2 * nt - 1,
        nv=2,
        dt=0.004,
        dr=1.0,
        usematmul=usematmul
    )
    x = np.ones(MDCop.shape[1])
    y = MDCop @ x

    MDCopjax = JaxOperator(pylops.waveeqprocessing.MDC(
        jnp.array((np.ones((20, 5, 4)) + 1j* np.ones((20, 5, 4)))).astype(np.complex64),
        nt=2 * nt - 1,
        nv=2,
        dt=0.004,
        dr=1.0,
        usematmul=usematmul
    ))

    xjnp = jnp.array(x).astype(jnp.float32)
    # block_until_ready() waits for JAX's asynchronous dispatch to finish
    yjnp = (MDCopjax @ xjnp).block_until_ready()
    print('y=', type(yjnp))

    dottest(MDCopjax, backend='jax', verb=True, atol=1e-2)

    %timeit -n 4 -r 2 MDCop @ x
    %timeit -n 4 -r 2 (MDCopjax @ xjnp).block_until_ready()

y= <class 'jaxlib.xla_extension.ArrayImpl'>
Dot test passed, v^H(Opu)=1.1786248683929443 - u^H(Op^Hv)=1.1786240339279175
209 µs ± 25.5 µs per loop (mean ± std. dev. of 2 runs, 4 loops each)
79.5 µs ± 8.98 µs per loop (mean ± std. dev. of 2 runs, 4 loops each)
y= <class 'jaxlib.xla_extension.ArrayImpl'>
Dot test passed, v^H(Opu)=3.2770657539367676 - u^H(Op^Hv)=3.2770698070526123
260 µs ± 27 µs per loop (mean ± std. dev. of 2 runs, 4 loops each)
75.9 µs ± 5.85 µs per loop (mean ± std. dev. of 2 runs, 4 loops each)


## PhaseShift

In [67]:
# PhaseShift extrapolation by zprop=100 in a constant 1500 velocity medium.
nt, nx = 201, 31
dt, dx = 0.004, 1
vel = 1500.0
zprop = 100
pad = 5
# non-negative frequencies of an nt-sample real FFT
freq = np.fft.rfftfreq(nt, dt)
# zero-centred wavenumbers for the laterally padded axis of nx + 2*pad samples
kx = np.fft.fftshift(np.fft.fftfreq(nx + 2 * pad, dx))

Pop = pylops.waveeqprocessing.PhaseShift(vel, zprop, nt, freq, kx, dtype=np.float32)
x = np.ones(Pop.shape[1])
y = Pop @ x

Popjax = JaxOperator(pylops.waveeqprocessing.PhaseShift(vel, zprop, nt, freq, kx, dtype=np.float32))

xjnp = jnp.array(x).astype(jnp.float32)
# block_until_ready() waits for JAX's asynchronous dispatch to finish
yjnp = (Popjax @ xjnp).block_until_ready()
print('y=', type(yjnp))

dottest(Popjax, backend='jax', verb=True, atol=1e-2)

%timeit -n 4 -r 2 Pop @ x
%timeit -n 4 -r 2 (Popjax @ xjnp).block_until_ready()

y= <class 'jaxlib.xla_extension.ArrayImpl'>
Dot test passed, v^H(Opu)=-50.261837005615234 - u^H(Op^Hv)=-50.2617073059082
649 µs ± 22.4 µs per loop (mean ± std. dev. of 2 runs, 4 loops each)
246 µs ± 8.43 µs per loop (mean ± std. dev. of 2 runs, 4 loops each)
