# Pylops - basic linear operators with JAX

### Author: M.Ravasi

In this notebook I will test some of the basic linear operators implemented in *Pylops* with JAX backend

In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import os
import numpy as np
import matplotlib.pyplot as plt
import scipy as sp
import pylops
import jax.numpy as jnp
import jax

from pylops.utils import dottest
from pylops import LinearOperator
from pylops import JaxOperator
from pylops.basicoperators import *
from pylops.signalprocessing import *
from pylops.optimization.basic import cgls, lsqr
from pylops.optimization.leastsquares import *

os.environ["JAX_PYLOPS"] = '1'



## Matrix Multiplication

In [2]:
ny, nx = 4, 4

G = np.random.normal(0, 1, (ny,nx)).astype('float32')
Gop = MatrixMult(G, dtype='float32')
x = np.ones(nx, dtype='float32')
y = Gop @ x

Gopjax = JaxOperator(MatrixMult(jnp.array(G), dtype='float32'))
xjnp = jnp.ones(nx, dtype='float32')
yjnp = (Gopjax @ xjnp).block_until_ready()
print('y=', type(yjnp))

dottest(Gopjax, ny, nx, backend='jax', verb=True, atol=1e-3)

%timeit -n 10 -r 2 Gop @ x
%timeit -n 10 -r 2 (Gopjax @ xjnp).block_until_ready()

I0000 00:00:1719163437.621939       1 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.


y= <class 'jaxlib.xla_extension.ArrayImpl'>
Dot test passed, v^H(Opu)=3.303278684616089 - u^H(Op^Hv)=3.303278684616089
8.04 µs ± 2.95 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
41.6 µs ± 7.6 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)


In [3]:
# Using ad for rmatvec
xxjnp = Gopjax.H @ yjnp
xxjnpad = Gopjax._rmatvecad(xjnp, yjnp)

xxjnp, xxjnpad

(Array([8.738094 , 2.6162288, 7.417057 , 4.5167336], dtype=float32),
 (Array([8.738094 , 2.6162288, 7.417057 , 4.516734 ], dtype=float32),))

In [4]:
# Inversion
xcgls = cgls(Gop, y, x0=np.zeros(nx), 
             niter=100, tol=1e-10, show=True)[0]

xcglsjnp = cgls(Gopjax, yjnp, x0=jnp.zeros(nx), 
                niter=100, tol=1e-10, show=True)[0]

xcgls, xcglsjnp

CGLS
-----------------------------------------------------------------
The Operator Op has 4 rows and 4 cols
damp = 0.000000e+00	tol = 1.000000e-10	niter = 100
-----------------------------------------------------------------

    Itn          x[0]              r1norm         r2norm
     1        1.2394e+00         8.8929e-01     8.8929e-01
     2        1.1484e+00         1.8027e-01     1.8027e-01
     3        1.0727e+00         7.5803e-02     7.5803e-02
     4        1.0000e+00         2.4448e-14     2.4448e-14

Iterations = 4        Total time (s) = 0.00
-----------------------------------------------------------------

CGLS
-----------------------------------------------------------------
The Operator Op has 4 rows and 4 cols
damp = 0.000000e+00	tol = 1.000000e-10	niter = 100
-----------------------------------------------------------------

    Itn          x[0]              r1norm         r2norm
     1        1.2394e+00         8.8929e-01     8.8929e-01
     2        1.1484e+00 

(array([1.00000014, 1.00000006, 0.99999978, 1.00000002]),
 Array([1.       , 1.       , 0.9999999, 0.9999999], dtype=float32))

## Identity

In [5]:
ny, nx = 5, 5 

Iop = Identity(ny, nx, dtype='float32')
x = np.arange(nx, dtype='float32')
y = Iop @ x

Iopjax = JaxOperator(Identity(ny, nx, dtype='float32'))
xjnp = jnp.arange(nx, dtype='float32')
yjnp = (Iopjax @ xjnp).block_until_ready()
y1jnp = (Iopjax.H @ yjnp).block_until_ready()

dottest(Iopjax, ny, nx, backend='jax', verb=True, atol=1e-3)

print('x  = ', x)
print('y  = ', yjnp, type(yjnp))
print('y1  = ', y1jnp, type(y1jnp))

%timeit -n 10 -r 2 Iop @ x
%timeit -n 10 -r 2 (Iopjax @ xjnp).block_until_ready()

Dot test passed, v^H(Opu)=-3.3537943363189697 - u^H(Op^Hv)=-3.3537943363189697
x  =  [0. 1. 2. 3. 4.]
y  =  [0. 1. 2. 3. 4.] <class 'jaxlib.xla_extension.ArrayImpl'>
y1  =  [0. 1. 2. 3. 4.] <class 'jaxlib.xla_extension.ArrayImpl'>
6.49 µs ± 1.48 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
54.2 µs ± 18.7 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)


In [6]:
ny, nx = 5, 7

Iop = Identity(ny, nx, dtype='float32')
x = np.arange(nx, dtype='float32')
y = Iop @ x

Iopjax = JaxOperator(Identity(ny, nx, dtype='float32'))
xjnp = jnp.arange(nx, dtype='float32')
yjnp = (Iopjax @ xjnp).block_until_ready()

dottest(Iopjax, ny, nx, backend='jax', verb=True, atol=1e-3)

print('x  = ', x)
print('y  = ', yjnp, type(yjnp))

%timeit -n 10 -r 2 Iop @ x
%timeit -n 10 -r 2 (Iopjax @ xjnp).block_until_ready()

Dot test passed, v^H(Opu)=0.44220054149627686 - u^H(Op^Hv)=0.44220054149627686
x  =  [0. 1. 2. 3. 4. 5. 6.]
y  =  [0. 1. 2. 3. 4.] <class 'jaxlib.xla_extension.ArrayImpl'>
7.76 µs ± 419 ns per loop (mean ± std. dev. of 2 runs, 10 loops each)
55.8 µs ± 19.7 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)


In [7]:
ny, nx = 7, 5

Iop = Identity(ny, nx, dtype='float32')
x = np.arange(nx, dtype='float32')
y = Iop @ x

Iopjax = JaxOperator(Identity(ny, nx, dtype='float32'))
xjnp = jnp.arange(nx, dtype='float32')
yjnp = (Iopjax @ xjnp).block_until_ready()

dottest(Iopjax, ny, nx, backend='jax', verb=True, atol=1e-3)

print('x  = ', x)
print('y  = ', yjnp, type(yjnp))

%timeit -n 10 -r 2 Iop @ x
%timeit -n 10 -r 2 (Iopjax @ xjnp).block_until_ready()

Dot test passed, v^H(Opu)=-0.8981446027755737 - u^H(Op^Hv)=-0.8981446027755737
x  =  [0. 1. 2. 3. 4.]
y  =  [0. 1. 2. 3. 4. 0. 0.] <class 'jaxlib.xla_extension.ArrayImpl'>
9.84 µs ± 1.57 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
46.4 µs ± 10.8 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)


## Diagonal

In [8]:
nx = 10

d = np.ones(nx, dtype='float32')
Dop = Diagonal(d, dtype='float32')
x = np.ones(nx, dtype='float32')
y = Dop @ x

djnp = jnp.ones(nx, dtype='float32')
Dopjax = JaxOperator(Diagonal(djnp, dtype='float32'))
xjnp = jnp.ones(nx)
yjnp = (Dopjax @ xjnp).block_until_ready()

dottest(Dopjax, nx, nx, backend='jax', verb=True, atol=1e-3)

print('y=', yjnp, type(yjnp))

%timeit -n 10 -r 2 Dop @ x
%timeit -n 10 -r 2 (Dopjax @ xjnp).block_until_ready()

Dot test passed, v^H(Opu)=-3.027193546295166 - u^H(Op^Hv)=-3.027193546295166
y= [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.] <class 'jaxlib.xla_extension.ArrayImpl'>
6.62 µs ± 1.44 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
44.9 µs ± 12.1 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)


In [9]:
# Complex numbers 
nx = 10

d = np.ones(nx, dtype='float32') + 1j*np.ones(nx, dtype='float32')
Dop = Diagonal(d, dtype='float32')
x = np.ones(nx, dtype='float32')+ 1j*np.ones(nx, dtype='float32')
y = Dop @ x

djnp = jnp.ones(nx, dtype='float32')
Dopjax = JaxOperator(Diagonal(djnp, dtype='float32'))
xjnp = jnp.ones(nx)
yjnp = (Dopjax @ xjnp).block_until_ready()

dottest(Dopjax, nx, nx, complexflag=2, backend='jax', verb=True, atol=1e-3)

print('y=', yjnp, type(yjnp))

%timeit -n 10 -r 2 Dop @ x
%timeit -n 10 -r 2 (Dopjax @ xjnp).block_until_ready()

Dot test passed, v^H(Opu)=(-0.6456664800643921+3.073392868041992j) - u^H(Op^Hv)=(-0.6456664800643921+3.073392868041992j)
y= [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.] <class 'jaxlib.xla_extension.ArrayImpl'>
6.95 µs ± 1.64 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
47.5 µs ± 14.8 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)


## Zero

In [10]:
ny, nx = 5, 5

Zop = Zero(ny, nx, dtype='float32')
x = np.ones(nx, dtype='float32')
y = Zop @ x

Zopjax = JaxOperator(Zero(ny, nx, dtype='float32'))
xjnp = jnp.ones(nx)
yjnp = (Zopjax @ xjnp).block_until_ready()

dottest(Zopjax, ny, nx, backend='jax', verb=True, atol=1e-3)

print('y=', yjnp, type(yjnp))

%timeit -n 10 -r 2 Zop @ x
%timeit -n 10 -r 2 (Zopjax @ xjnp).block_until_ready()

Dot test passed, v^H(Opu)=0.0 - u^H(Op^Hv)=0.0
y= [0. 0. 0. 0. 0.] <class 'jaxlib.xla_extension.ArrayImpl'>
6.31 µs ± 1.44 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
45.4 µs ± 12.3 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)


In [11]:
ny, nx = 7, 5

Zop = Zero(ny, nx, dtype='float32')
x = np.ones(nx, dtype='float32')
y = Zop @ x

Zopjax = JaxOperator(Zero(ny, nx, dtype='float32'))
xjnp = jnp.ones(nx)
yjnp = (Zopjax @ xjnp).block_until_ready()

dottest(Zopjax, ny, nx, backend='jax', verb=True, atol=1e-3)

print('y=', yjnp, type(yjnp))

%timeit -n 10 -r 2 Zop @ x
%timeit -n 10 -r 2 (Zopjax @ xjnp).block_until_ready()

Dot test passed, v^H(Opu)=0.0 - u^H(Op^Hv)=0.0
y= [0. 0. 0. 0. 0. 0. 0.] <class 'jaxlib.xla_extension.ArrayImpl'>
6.29 µs ± 1.69 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
52 µs ± 20.4 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)


In [12]:
ny, nx = 5, 7

Zop = Zero(ny, nx, dtype='float32')
x = np.ones(nx, dtype='float32')
y = Zop @ x

Zopjax = JaxOperator(Zero(ny, nx, dtype='float32'))
xjnp = jnp.ones(nx)
yjnp = (Zopjax @ xjnp).block_until_ready()

dottest(Zopjax, ny, nx, backend='jax', verb=True, atol=1e-3)

print('y=', yjnp, type(yjnp))

%timeit -n 10 -r 2 Zop @ x
%timeit -n 10 -r 2 (Zopjax @ xjnp).block_until_ready()

Dot test passed, v^H(Opu)=0.0 - u^H(Op^Hv)=0.0
y= [0. 0. 0. 0. 0.] <class 'jaxlib.xla_extension.ArrayImpl'>
6.64 µs ± 1.71 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
45.8 µs ± 12 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)


## Sum

In [13]:
ny, nx = 5, 7

Sop = Sum(dims=(ny, nx), axis=0)
x   = (np.arange(ny*nx)).reshape(ny, nx)
y = Sop @ x

Sopjax = JaxOperator(Sum(dims=(ny, nx), axis=0, dtype='float32'))
xjnp = (jnp.arange(ny*nx)).reshape(ny, nx)
yjnp = (Sopjax @ xjnp).block_until_ready()

dottest(Sopjax, nx, ny*nx, backend='jax', verb=True, atol=1e-3)

print('y=', type(yjnp))

%timeit -n 10 -r 2 Sop @ x
%timeit -n 10 -r 2 (Sopjax @ xjnp).block_until_ready()

Dot test passed, v^H(Opu)=2.447092056274414 - u^H(Op^Hv)=2.447092056274414
y= <class 'jaxlib.xla_extension.ArrayImpl'>
9.77 µs ± 1.95 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
47.6 µs ± 12.7 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)


## Transpose

In [14]:
ny, nx = 20, 40
dims = (ny, nx)

Top = Transpose(dims=dims, axes=(1,0))
x = np.arange(ny*nx).reshape(dims)
y = Top @ x

Topjax = JaxOperator(Transpose(dims=dims, axes=(1,0), dtype='float32'))
xjnp = (jnp.arange(ny*nx)).reshape(dims)
yjnp = (Topjax @ xjnp).block_until_ready()

dottest(Topjax, ny*nx, ny*nx, backend='jax', verb=True, atol=1e-3)

print('y=', type(yjnp))

%timeit -n 10 -r 2 Top @ x
%timeit -n 10 -r 2 (Topjax @ xjnp).block_until_ready()

Dot test passed, v^H(Opu)=-6.011893272399902 - u^H(Op^Hv)=-6.011882305145264
y= <class 'jaxlib.xla_extension.ArrayImpl'>
10.5 µs ± 2.35 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
85 µs ± 19.3 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)


## Flip

In [15]:
nt = 10

Fop = Flip(nt)
x = np.arange(nt)
y = Fop @ x

Fopjax = JaxOperator(Flip(nt, dtype='float32'))
xjnp = jnp.arange(nt)
yjnp = (Fopjax @ xjnp).block_until_ready()

dottest(Fopjax, nt, nt, backend='jax', verb=True, atol=1e-3)

print('y=', type(yjnp))

%timeit -n 10 -r 2 Fop @ x
%timeit -n 10 -r 2 (Fopjax @ xjnp).block_until_ready()

Dot test passed, v^H(Opu)=-0.7547396421432495 - u^H(Op^Hv)=-0.7547395825386047
y= <class 'jaxlib.xla_extension.ArrayImpl'>
11.7 µs ± 2.22 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
45.5 µs ± 12 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)


In [16]:
nt, nx = 10, 5
x = np.outer(np.arange(nt), np.ones(nx))

for axis in (0, 1):
    Fop = Flip(dims=(nt, nx), axis=axis)
    y = Fop @ x

    Fopjax = JaxOperator(Flip(dims=(nt, nx), axis=axis, dtype='float32'))
    xjnp = jnp.outer(jnp.arange(nt), jnp.ones(nx))
    yjnp = (Fopjax @ xjnp).block_until_ready()

    dottest(Fopjax, nt*nx, nt*nx, backend='jax', verb=True, atol=1e-6)

    print('y=', type(yjnp))

    %timeit -n 10 -r 2 Fop @ x
    %timeit -n 10 -r 2 (Fopjax @ xjnp).block_until_ready()

Dot test passed, v^H(Opu)=-15.899504661560059 - u^H(Op^Hv)=-15.899502754211426
y= <class 'jaxlib.xla_extension.ArrayImpl'>
17.2 µs ± 6.33 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
87.1 µs ± 20.9 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
Dot test passed, v^H(Opu)=4.159647464752197 - u^H(Op^Hv)=4.1596479415893555
y= <class 'jaxlib.xla_extension.ArrayImpl'>
11.2 µs ± 2.37 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
81.8 µs ± 17.8 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)


In [17]:
nt, nx, ny = 2, 3, 4
x = np.outer(np.ones(nt), np.ones(nx))[:, :, np.newaxis] * np.arange(ny)

for axis in (0, 1, 2):
    Fop = Flip(dims=(nt, nx, ny), axis=axis)
    y = Fop @ x

    Fopjax = JaxOperator(Flip(dims=(nt, nx, ny), axis=axis, dtype='float32'))
    xjnp = jnp.outer(jnp.ones(nt), jnp.ones(nx))[:, :, jnp.newaxis] * jnp.arange(ny)
    yjnp = (Fopjax @ xjnp).block_until_ready()

    dottest(Fopjax, nt*nx*ny, nt*nx*ny, backend='jax', verb=True, atol=1e-6)

    print('y=', type(yjnp))

    %timeit -n 10 -r 2 Fop @ x
    %timeit -n 10 -r 2 (Fopjax @ xjnp).block_until_ready()

Dot test passed, v^H(Opu)=-2.8223679065704346 - u^H(Op^Hv)=-2.8223681449890137
y= <class 'jaxlib.xla_extension.ArrayImpl'>
12 µs ± 1.49 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
86 µs ± 20.4 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
Dot test passed, v^H(Opu)=-0.6433626413345337 - u^H(Op^Hv)=-0.6433628797531128
y= <class 'jaxlib.xla_extension.ArrayImpl'>
10.2 µs ± 2.43 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
70.3 µs ± 8.3 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
Dot test passed, v^H(Opu)=2.4860689640045166 - u^H(Op^Hv)=2.4860692024230957
y= <class 'jaxlib.xla_extension.ArrayImpl'>
15 µs ± 3 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
80 µs ± 10.1 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)


## Roll

In [18]:
nt, nx = 10, 5
x = np.outer(np.arange(nt), np.ones(nx))


for axis in (0, 1):
    
    Rop = Roll((nt, nx), axis=axis)
    y = Rop @ x

    Ropjax = JaxOperator(Roll((nt, nx), axis=axis, dtype='float32'))
    xjnp = jnp.outer(jnp.arange(nt), jnp.ones(nx))
    yjnp = (Ropjax @ xjnp).block_until_ready()

    dottest(Ropjax, nt*nx, nt*nx, backend='jax', verb=True, atol=1e-3)

    print('y=', type(yjnp))

    %timeit -n 10 -r 2 Rop @ x
    %timeit -n 10 -r 2 (Ropjax @ xjnp).block_until_ready()

Dot test passed, v^H(Opu)=-2.4878180027008057 - u^H(Op^Hv)=-2.487818717956543
y= <class 'jaxlib.xla_extension.ArrayImpl'>
21.7 µs ± 3.72 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
88.1 µs ± 24.5 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
Dot test passed, v^H(Opu)=16.584728240966797 - u^H(Op^Hv)=16.584728240966797
y= <class 'jaxlib.xla_extension.ArrayImpl'>
46.3 µs ± 12 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
111 µs ± 43.3 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)


In [19]:
nt = 10

Rop = Roll(nt, shift=-5)
x = np.arange(nt)
y = Rop @ x

Ropjax = JaxOperator(Roll(nt, shift=-5, dtype='float32'))
xjnp = jnp.arange(nt)
yjnp = (Ropjax @ xjnp).block_until_ready()

dottest(Ropjax, nt, nt, backend='jax', verb=True, atol=1e-6)

print('y=', type(yjnp))

%timeit -n 10 -r 2 Rop @ x
%timeit -n 10 -r 2 (Ropjax @ xjnp).block_until_ready()

Dot test passed, v^H(Opu)=2.77858829498291 - u^H(Op^Hv)=2.77858829498291
y= <class 'jaxlib.xla_extension.ArrayImpl'>
22.2 µs ± 4.12 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
59.2 µs ± 17.5 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)


## Pad

In [20]:
dims = 10
pad = (10, 3)

Pop = Pad(dims, pad)
x = np.arange(dims)+1.
y = Pop @ x

Popjax = JaxOperator(Pad(dims, pad, dtype='float32'))
xjnp = jnp.arange(dims)+1.
yjnp = (Popjax @ xjnp).block_until_ready()

dottest(Popjax, dims+pad[0]+pad[1], dims, backend='jax', verb=True, atol=1e-6)

print('y=', type(yjnp))

%timeit -n 10 -r 2 Pop @ x
%timeit -n 10 -r 2 (Popjax @ xjnp).block_until_ready()

Dot test passed, v^H(Opu)=1.8868082761764526 - u^H(Op^Hv)=1.8868082761764526
y= <class 'jaxlib.xla_extension.ArrayImpl'>
36.1 µs ± 3.62 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
62.3 µs ± 24.8 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)


In [21]:
dims = (5, 4)
pad = ((1, 0), (3, 4))

Pop = Pad(dims, pad)
x = (np.arange(np.prod(np.array(dims)))+1.).reshape(dims)
y = Pop @ x

Popjax = JaxOperator(Pad(dims, pad, dtype='float32'))
xjnp = (jnp.arange(jnp.prod(jnp.array(dims)))+1.).reshape(dims)
yjnp = (Popjax @ xjnp).block_until_ready()

dottest(Popjax, backend='jax', verb=True, atol=1e-6)

print('y=', type(yjnp))

%timeit -n 10 -r 2 Pop @ x
%timeit -n 10 -r 2 (Popjax @ xjnp).block_until_ready()

Dot test passed, v^H(Opu)=-0.5775524973869324 - u^H(Op^Hv)=-0.5775524973869324
y= <class 'jaxlib.xla_extension.ArrayImpl'>
60.2 µs ± 14.4 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
116 µs ± 26.5 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)


## Symmetrize

In [22]:
nt = 10

Sop = Symmetrize(nt)
x = np.arange(nt)+1
y = Sop @ x

Sopjax = JaxOperator(Symmetrize(nt, dtype='float32'))
xjnp = jnp.arange(nt)+1
yjnp = (Sopjax @ xjnp).block_until_ready()

dottest(Sopjax, backend='jax', verb=True, atol=1e-6)

print('y=', type(yjnp))

%timeit -n 10 -r 2 Sop @ x
%timeit -n 10 -r 2 (Sopjax @ xjnp).block_until_ready()

Dot test passed, v^H(Opu)=-7.6267523765563965 - u^H(Op^Hv)=-7.626751899719238
y= <class 'jaxlib.xla_extension.ArrayImpl'>
19.6 µs ± 1.3 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
62.5 µs ± 27.2 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)


## Restriction

In [23]:
N=200

perc_subsampling=0.4
Nsub = int(np.round(N*perc_subsampling))
iava = np.sort(np.random.permutation(np.arange(N))[:Nsub])

for inplace in (True, False):
    Rop = Restriction(N, iava, inplace=inplace, dtype='float64')
    x = np.zeros(N)
    y = Rop @ x

    Ropjax = JaxOperator(Restriction(N, iava, inplace=inplace, dtype='float32'))
    xjnp = jnp.zeros(N)
    yjnp = (Ropjax @ xjnp).block_until_ready()

    dottest(Ropjax, backend='jax', verb=True, atol=1e-6)

    print('y=', type(yjnp))

    %timeit -n 10 -r 2 Rop @ x
    %timeit -n 10 -r 2 (Ropjax @ xjnp).block_until_ready()

Dot test passed, v^H(Opu)=10.850922584533691 - u^H(Op^Hv)=10.850922584533691
y= <class 'jaxlib.xla_extension.ArrayImpl'>
12.2 µs ± 3.4 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
61.6 µs ± 12.2 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
Dot test passed, v^H(Opu)=11.268792152404785 - u^H(Op^Hv)=11.268792152404785
y= <class 'jaxlib.xla_extension.ArrayImpl'>
21 µs ± 7.44 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
63.9 µs ± 22.5 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)


## Regression

In [24]:
# parameters
N = 30
t = np.arange(N, dtype='float64')

LRop = LinearRegression(jnp.array(t), dtype=None)
x = np.array([1., 2.])
y = LRop @ x

LRopjax = JaxOperator(LinearRegression(t, dtype=None))
xjnp = jnp.array(x)
yjnp = (LRopjax @ xjnp).block_until_ready()

dottest(LRopjax, N, 2, backend='jax', verb=True, atol=1e-3)

print('y=', type(yjnp))

%timeit -n 10 -r 2 LRop @ x
%timeit -n 10 -r 2 (LRopjax @ xjnp).block_until_ready()

Dot test passed, v^H(Opu)=-15.633894920349121 - u^H(Op^Hv)=-15.633901596069336
y= <class 'jaxlib.xla_extension.ArrayImpl'>
240 µs ± 46.2 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
34.1 µs ± 2.46 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)


  u = randn(nc, backend).astype(rdtype)
  v = randn(nr, backend).astype(rdtype)


In [25]:
# CGLS solver
xcglsjnp = cgls(LRopjax, yjnp, jnp.zeros(2), damp=1e-10, niter=10 ,show=0)[0]

print('cgls solution xlsqr=', xcglsjnp)

cgls solution xlsqr= [0.99999994 2.        ]


## First Derivative

In [26]:
nx = 11
   
for kind in ('forward', 'centered', 'backward'):
    for edge in (False, True):
        for order in (3, 5):
            print(kind, edge, order)
            D1op = FirstDerivative(nx, edge=edge, kind=kind, 
                                   order=order if kind == 'centered' else 3, 
                                   dtype='float32')
            x = np.ones(nx)
            y = D1op @ x

            D1opjax = JaxOperator(FirstDerivative(nx, edge=edge, kind=kind, 
                                                  order=order if kind == 'centered' else 3, 
                                                  dtype='float32'))
            xjnp = jnp.array(x)
            yjnp = (D1opjax @ xjnp).block_until_ready()

            dottest(D1opjax, nx, nx, backend='jax', verb=True, atol=1e-3)
            print('y=', type(yjnp))

            %timeit -n 10 -r 2 D1op @ x
            %timeit -n 10 -r 2 (D1opjax @ xjnp).block_until_ready()

forward False 3
Dot test passed, v^H(Opu)=-1.03773033618927 - u^H(Op^Hv)=-1.0377304553985596
y= <class 'jaxlib.xla_extension.ArrayImpl'>
15.5 µs ± 3.5 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
45.8 µs ± 3.16 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
forward False 5
Dot test passed, v^H(Opu)=-0.6470524072647095 - u^H(Op^Hv)=-0.6470521688461304
y= <class 'jaxlib.xla_extension.ArrayImpl'>
10.6 µs ± 2.02 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
32.5 µs ± 1.98 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
forward True 3
Dot test passed, v^H(Opu)=-9.530404090881348 - u^H(Op^Hv)=-9.530405044555664
y= <class 'jaxlib.xla_extension.ArrayImpl'>
26.1 µs ± 1.39 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
68.3 µs ± 28.6 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
forward True 5
Dot test passed, v^H(Opu)=11.062045097351074 - u^H(Op^Hv)=11.062047004699707
y= <class 'jaxlib.xla_extension.ArrayImpl'>
15.4 µs ± 3.44 µs per loop 

In [27]:
ny, nx = 21, 11
   
for kind in ('forward', 'centered', 'backward'):
    for edge in (False, True):
        print(kind, edge, order)
        D1op = FirstDirectionalDerivative((ny, nx), v=0.5 * np.ones(2), edge=edge, kind=kind, 
                                          dtype='float32')
        x = np.ones(ny*nx)
        y = D1op @ x

        D1opjax = JaxOperator(FirstDirectionalDerivative((ny, nx), v=0.5 * np.ones(2), edge=edge, kind=kind, 
                                                         dtype='float32'))
        xjnp = jnp.array(x)
        yjnp = (D1opjax @ xjnp).block_until_ready()

        dottest(D1opjax, ny*nx, ny*nx, backend='jax', verb=True, atol=1e-3)
        print('y=', type(yjnp))

        %timeit -n 10 -r 2 D1op @ x
        %timeit -n 10 -r 2 (D1opjax @ xjnp).block_until_ready()

forward False 5
Dot test passed, v^H(Opu)=1.154794692993164 - u^H(Op^Hv)=1.1547937393188477
y= <class 'jaxlib.xla_extension.ArrayImpl'>
91.8 µs ± 9.84 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
46.9 µs ± 9.91 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
forward True 5
Dot test passed, v^H(Opu)=5.504953861236572 - u^H(Op^Hv)=5.504956245422363
y= <class 'jaxlib.xla_extension.ArrayImpl'>
78.8 µs ± 20.4 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
30.3 µs ± 3.29 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
centered False 5
Dot test passed, v^H(Opu)=-10.018988609313965 - u^H(Op^Hv)=-10.018983840942383
y= <class 'jaxlib.xla_extension.ArrayImpl'>
50.3 µs ± 5.76 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
27.1 µs ± 4.73 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
centered True 5
Dot test passed, v^H(Opu)=-1.1075350046157837 - u^H(Op^Hv)=-1.1075341701507568
y= <class 'jaxlib.xla_extension.ArrayImpl'>
75.1 µs ± 22.2 µs per loo

## Second Derivative

In [28]:
nx = 11
   
for kind in ('forward', 'centered', 'backward'):
    for edge in (False, True):
        print(kind, edge, order)
        D1op = SecondDerivative(nx, edge=edge, kind=kind, 
                                dtype='float32')
        x = np.ones(nx)
        y = D1op @ x

        D1opjax = JaxOperator(SecondDerivative(nx, edge=edge, kind=kind, dtype='float32'))
        xjnp = jnp.array(x)
        yjnp = (D1opjax @ xjnp).block_until_ready()

        dottest(D1opjax, nx, nx, backend='jax', verb=True, atol=1e-3)
        print('y=', type(yjnp))

        %timeit -n 10 -r 2 D1op @ x
        %timeit -n 10 -r 2 (D1opjax @ xjnp).block_until_ready()

forward False 5
Dot test passed, v^H(Opu)=0.6915369033813477 - u^H(Op^Hv)=0.6915367841720581
y= <class 'jaxlib.xla_extension.ArrayImpl'>
17.7 µs ± 4.48 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
36.5 µs ± 5.41 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
forward True 5
Dot test passed, v^H(Opu)=7.171112060546875 - u^H(Op^Hv)=7.171111583709717
y= <class 'jaxlib.xla_extension.ArrayImpl'>
17.1 µs ± 4.11 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
63.8 µs ± 18.1 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
centered False 5
Dot test passed, v^H(Opu)=1.8179728984832764 - u^H(Op^Hv)=1.8179726600646973
y= <class 'jaxlib.xla_extension.ArrayImpl'>
16.3 µs ± 3.87 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
38.7 µs ± 7.1 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
centered True 5
Dot test passed, v^H(Opu)=-2.891242504119873 - u^H(Op^Hv)=-2.891242504119873
y= <class 'jaxlib.xla_extension.ArrayImpl'>
34.9 µs ± 3.27 µs per loop (m

## Laplacian

In [29]:
ny, nx = 21, 11
   
for kind in ('forward', 'centered', 'backward'):
    for edge in (False, True):
        print(kind, edge, order)
        D1op = Laplacian((ny, nx), edge=edge, kind=kind, dtype='float32')
        x = np.ones(ny*nx)
        y = D1op @ x

        D1opjax = JaxOperator(Laplacian((ny, nx), edge=edge, kind=kind, dtype='float32'))
        xjnp = jnp.array(x)
        yjnp = (D1opjax @ xjnp).block_until_ready()

        dottest(D1opjax, ny*nx, ny*nx, backend='jax', verb=True, atol=1e-3)
        print('y=', type(yjnp))

        %timeit -n 10 -r 2 D1op @ x
        %timeit -n 10 -r 2 (D1opjax @ xjnp).block_until_ready()

forward False 5
Dot test passed, v^H(Opu)=17.758832931518555 - u^H(Op^Hv)=17.758827209472656
y= <class 'jaxlib.xla_extension.ArrayImpl'>
65.8 µs ± 13 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
24.8 µs ± 3.85 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
forward True 5
Dot test passed, v^H(Opu)=13.350592613220215 - u^H(Op^Hv)=13.350592613220215
y= <class 'jaxlib.xla_extension.ArrayImpl'>
98.2 µs ± 39.2 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
28.4 µs ± 6.3 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
centered False 5
Dot test passed, v^H(Opu)=13.105086326599121 - u^H(Op^Hv)=13.1051025390625
y= <class 'jaxlib.xla_extension.ArrayImpl'>
64 µs ± 3.36 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
35.8 µs ± 14.5 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
centered True 5
Dot test passed, v^H(Opu)=118.86038208007812 - u^H(Op^Hv)=118.86035919189453
y= <class 'jaxlib.xla_extension.ArrayImpl'>
88.8 µs ± 5.62 µs per loop (mean 

## Gradient

In [30]:
ny, nx = 21, 11
   
for kind in ('forward', 'centered', 'backward'):
    for edge in (False, True):
        print(kind, edge, order)
        D1op = Gradient((ny, nx), edge=edge, kind=kind, dtype='float32')
        x = np.ones(ny*nx)
        y = D1op @ x

        D1opjax = JaxOperator(Gradient((ny, nx), edge=edge, kind=kind, dtype='float32'))
        xjnp = jnp.array(x)
        yjnp = (D1opjax @ xjnp).block_until_ready()

        dottest(D1opjax, 2*ny*nx, ny*nx, backend='jax', verb=True, atol=1e-3)
        print('y=', type(yjnp))

        %timeit -n 10 -r 2 D1op @ x
        %timeit -n 10 -r 2 (D1opjax @ xjnp).block_until_ready()

forward False 5
Dot test passed, v^H(Opu)=14.915897369384766 - u^H(Op^Hv)=14.915921211242676
y= <class 'jaxlib.xla_extension.ArrayImpl'>
29 µs ± 3.79 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
32.8 µs ± 9.88 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
forward True 5
Dot test passed, v^H(Opu)=26.62845802307129 - u^H(Op^Hv)=26.628450393676758
y= <class 'jaxlib.xla_extension.ArrayImpl'>
55.7 µs ± 14.9 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
26.8 µs ± 4.52 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
centered False 5
Dot test passed, v^H(Opu)=-17.402027130126953 - u^H(Op^Hv)=-17.402027130126953
y= <class 'jaxlib.xla_extension.ArrayImpl'>
83.3 µs ± 10.5 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
25.2 µs ± 3.63 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
centered True 5
Dot test passed, v^H(Opu)=30.421144485473633 - u^H(Op^Hv)=30.4211368560791
y= <class 'jaxlib.xla_extension.ArrayImpl'>
73.5 µs ± 29.8 µs per loop (m

## Causal integration

In [31]:
nx = 1000000

Cop = CausalIntegration(nx, dtype='float32')

x = np.ones(nx, dtype='float32')
y = Cop @ x

Copjax = JaxOperator(Cop)

xjnp = jnp.ones(nx)
yjnp = (Copjax @ xjnp).block_until_ready()
print('y=', type(yjnp))

%timeit -n 10 -r 2 Cop @ x
%timeit -n 10 -r 2 (Copjax @ xjnp).block_until_ready()

y= <class 'jaxlib.xla_extension.ArrayImpl'>
4.03 ms ± 26.6 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
3 ms ± 524 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)


## VStack

In [32]:
D2hop = SecondDerivative(dims=[11,21], axis=1, dtype='float64')
D2vop = SecondDerivative(dims=[11,21], axis=0, dtype='float64')

D1hopjax = JaxOperator(FirstDerivative(dims=[11,21], axis=1, dtype='float32'))
D1vopjax = JaxOperator(FirstDerivative(dims=[11,21], axis=0, dtype='float32'))
    
Vstackop = VStack([D1hopjax,D1vopjax])
xjnp = jnp.ones(11*21)
yjnp = (Vstackop @ xjnp).block_until_ready()

dottest(Vstackop, backend='jax', verb=True, atol=1e-3)
print('y=', type(yjnp))

Dot test passed, v^H(Opu)=-10.478123664855957 - u^H(Op^Hv)=-10.478117942810059
y= <class 'jaxlib.xla_extension.ArrayImpl'>


## HStack

In [33]:
Hstackop = HStack([D1hopjax,D1vopjax])
xjnp = jnp.ones(2*11*21)
yjnp = (Hstackop @ xjnp).block_until_ready()

dottest(Hstackop, backend='jax', verb=True, atol=1e-3)
print('y=', type(yjnp))

Dot test passed, v^H(Opu)=-0.7496676445007324 - u^H(Op^Hv)=-0.7496649622917175
y= <class 'jaxlib.xla_extension.ArrayImpl'>


## BlockDiag

In [34]:
Blockop = BlockDiag([D1hopjax,D1vopjax])
xjnp = jnp.ones(2*11*21)
yjnp = (Blockop @ xjnp).block_until_ready()

dottest(Blockop, backend='jax', verb=True, atol=1e-3)
print('y=', type(yjnp))

Dot test passed, v^H(Opu)=-32.236759185791016 - u^H(Op^Hv)=-32.23677062988281
y= <class 'jaxlib.xla_extension.ArrayImpl'>


## FFT

In [35]:
nx = 100000

Fop = FFT(nx, dtype='float32')

x = np.ones(nx, dtype='float32')
y = Fop @ x

Fopjax = JaxOperator(Fop)

xjnp = jnp.ones(nx, dtype='float32')
yjnp = (Fopjax @ xjnp).block_until_ready()
print('y=', type(yjnp))

%timeit -n 10 -r 2 Fop @ x
%timeit -n 10 -r 2 (Fopjax @ xjnp).block_until_ready()

y= <class 'jaxlib.xla_extension.ArrayImpl'>
1.32 ms ± 19.5 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)
772 µs ± 44.1 µs per loop (mean ± std. dev. of 2 runs, 10 loops each)




## Convolve1D

In [36]:
nx = 10
offset = 4

h = np.arange(-3, 3)
print(h)

Cop = Convolve1D(nx, h=h, offset=offset, 
                 dtype='float32', 
                 method='direct')
x  = np.zeros(nx)
x[3] = 1
y = Cop @ x

hjnp = jnp.array(h).astype(jnp.float32)
Copjax = JaxOperator(Convolve1D(nx, h=hjnp, offset=offset, 
                     dtype='float32', 
                     method='direct'))

xjnp = jnp.array(x).astype(jnp.float32)
yjnp = (Copjax @ xjnp).block_until_ready()
print('y=', type(yjnp))

#dottest(Copjax, backend='jax', verb=True, atol=1e-3)

#%timeit -n 10 -r 2 Cop @ x
#%timeit -n 10 -r 2 (Copjax @ xjnp).block_until_ready()

[-3 -2 -1  0  1  2]
self.convfunc, self.method <function convolve at 0x7f98d100b550> direct
x, h (10,) (9,)
j_convolve
j_convolve
self.convfunc, self.method <function convolve at 0x7f98f161f790> direct
x, h (10,) (9,)
y= <class 'jaxlib.xla_extension.ArrayImpl'>


In [37]:
nt = 301
nx = 20
dt = 0.004
t  = np.arange(nt)*dt

h = np.ones(3)
hcenter=1

Cop = Convolve1D(dims=[nt,nx], h=h, offset=hcenter, axis=0, dtype='float32')
x  = np.zeros((nt, nx))
x[int(nt/2),:] = 1
y = Cop @ x

hjnp = jnp.array(h).astype(jnp.float32)
Copjax = JaxOperator(Convolve1D(dims=[nt,nx], h=hjnp, offset=hcenter, axis=0, dtype='float32'))

xjnp = jnp.array(x).astype(jnp.float32)
yjnp = (Copjax @ xjnp).block_until_ready()
print('y=', type(yjnp))

dottest(Copjax, backend='jax', verb=True, atol=1e-3)

%timeit -n 10 -r 2 Cop @ x
%timeit -n 10 -r 2 (Copjax @ xjnp).block_until_ready()

self.convfunc, self.method functools.partial(<function fftconvolve at 0x7f98d0ffbf70>, axes=0) fft
x, h (301, 20) (3, 1)
j_fftconvolve
j_fftconvolve
self.convfunc, self.method functools.partial(<function fftconvolve at 0x7f98f161f5e0>, axes=0) fft
x, h (301, 20) (3, 1)


ValueError: mapped axes must have same shape; got in1.shape=(301, 20) in2.shape=(3, 1) axes=(0,)

In [49]:
ny, nx = 400, 400
G = jnp.array(np.random.normal(0, 1, (ny, nx)).astype(np.float32))
x = jnp.ones(nx, dtype=np.float32)

Gop = MatrixMult(G, dtype='float32')
y = Gop * x
xest = Gop / y