In [13]:
import mat73
import jax
import jax.numpy as jnp
from jax.experimental.sparse import BCOO, BCSR
from jax.experimental.sparse.linalg import spsolve
import matplotlib.pyplot as plt
import numpy as np
import scipy.sparse as sp
import scipy.sparse.linalg as spla
import time

#Load data problem
data = mat73.loadmat('RecordedData.mat')
x      = jnp.array(data['x'],      dtype=jnp.float32)
y        = jnp.array(data['y'],      dtype=jnp.float32)
C        = jnp.array(data['C'],      dtype=jnp.float32)
x_circ   = jnp.array(data['x_circ'], dtype=jnp.float32)
y_circ   = jnp.array(data['y_circ'], dtype=jnp.float32)
f_data   = jnp.array(data['f'],      dtype=jnp.float32)


In [14]:
REC_DATA = data['REC_DATA']

In [15]:
numElements = x_circ.size
assert numElements == y_circ.size, \
       "x_circ and y_circ must have the same length"

In [16]:
# Which subset of transmit to use
dwnsmp = 1
tx_include = jnp.arange(0,numElements,dwnsmp)
REC_DATA = REC_DATA[tx_include,:]

In [17]:
# Extract Subset of Signals within Acceptance Angle
numElemLeftRightExcl = 31
elemLeftRightExcl    = jnp.arange(-numElemLeftRightExcl,numElemLeftRightExcl + 1)
elem_include         = jnp.ones((numElements, numElements),dtype=bool)

for tx_element in range(numElements):
    elemLeftRightExclCurrent = (elemLeftRightExcl + tx_element)
    elem_include = elem_include.at[tx_element, elemLeftRightExclCurrent].set(False)


In [18]:
elem_include.shape

(256, 256)

Frequency-Domain Full Waveform Inversion (FWI)

In [19]:
#Parameters for Conjugate Gradient Reconstruction
Niter = 1 #Number of Iterations
momentumFormula = 4 #Momentum Formula for Conjugate Gradient
                    # 0 -- No Momentum (Gradient Descent)
                    # 1 -- Fletcher-Reeves (FR)
                    # 2 -- Polak-Ribiere (PR)
                    # 3 -- Combined FR + PR
                    # 4 -- Hestenes-Stiefel (HS)
stepSizeCalculation = 1 #Which Step Size Calculation:
                        # 1 -- Not Involving Gradient Nor Search Direction
                        # 2 -- Involving Gradient BUT NOT Search Direction
                        # 3 -- Involving Gradient AND Search Direction
c_init = 1480 # Initial Homogeneous Sound Speed [m/s] Guess

In [20]:
#Computational Grid (and Element Placement on Grid) for Reconstruction
dxi = 0.8e-3
xmax = 120e-3
xi = jnp.arange(-xmax,xmax+dxi,dxi) 
yi = xi.copy()
Nxi = xi.size
Nyi = yi.size
[Xi, Yi] = jnp.meshgrid(xi, yi)

xc = x_circ.ravel()   # shape (M,)
yc = y_circ.ravel()   # shape (M,)

x_idx = jnp.argmin(jnp.abs(xi[None, :] - xc[:, None]), axis=1)
y_idx = jnp.argmin(jnp.abs(yi[None, :] - yc[:, None]), axis=1)

ind = y_idx * Nxi + x_idx #Row majo

In [21]:
x_idx.shape

(256,)

In [22]:
#Solver Options for Helmholtz Equation
a0 = 10 #PML Constant
L_PML = 9.0e-3 #Thickness of PML  

#Generate Sources
SRC = jnp.zeros((Nyi, Nxi, tx_include.size), dtype=jnp.float32)

for tx_elmt_idx in range(tx_include.size):
    #Single Element Source
    x_idx_src = x_idx[tx_include[tx_elmt_idx]]
    y_idx_src = y_idx[tx_include[tx_elmt_idx]] 
    SRC = SRC.at[y_idx_src, x_idx_src, tx_elmt_idx].set(1)

KeyboardInterrupt: 

In [None]:
#(Nonlinear) Conjugate Gradient
search_dir = jnp.zeros((Nyi,Nxi)) # Conjugate Gradient Direction
gradient_img_prev = jnp.zeros((Nyi,Nxi)) # Previous Gradient Image
VEL_ESTIM = c_init*jnp.ones((Nyi,Nxi)) # Initial Sound Speed Image [m/s]
SLOW_ESTIM = 1./VEL_ESTIM # Initial Slowness Image [s/m]
crange = jnp.array([1400, 1600]) # For reconstruction display [m/s]

In [None]:
crange.shape

In [25]:
for iter in range(Niter):
    #Step 1: Calculate Gradient/Backprojection
    #(1A) Solve Forward Helmholtz Equation (H is Helmholtz matrix and u is the wavefield)
    t = time.time()
    WVFIELD = solveHelmholtz(xi, yi, VEL_ESTIM, SRC, f_data, a0, L_PML, False)

WVFIELD


BCOO(complex64[90601, 90601], nse=805809)
[1.+0.j 1.+0.j 1.+0.j ... 1.+0.j 1.+0.j 1.+0.j]
[    0     1     2 ... 90598 90599 90600]
[     0      1      2 ... 805807 805808 805809]


KeyboardInterrupt: 

In [None]:
WVFIELD

In [23]:
def solveHelmholtz(x, y, vel, src, f, a0, L_PML, adjoint):
    """
    Solve 2D Helmholtz with a 9-point optimized stencil & PML in JAX.
    Inputs:
      x:        (Nx,) 1D array of grid-x
      y:        (Ny,) 1D array of grid-y
      vel:      (Ny,Nx) wave velocity map
      src:      (Ny,Nx,S) source array (S shots)
      f:        scalar frequency
      a0:       PML strength
      L_PML:    PML thickness
      adjoint:  bool, True→solve H^H, False→solve H
    Returns:
      wvfield:  (Ny,Nx,S) solved wavefields
    """
    # 1) grid & spacing
    h  = jnp.mean(jnp.diff(x))
    gh = jnp.mean(jnp.diff(y))
    g  = gh / h
    Nx, Ny = x.size, y.size
    Ntot   = Nx * Ny

    # 2) wavenumber (elementwise)
    k = 2 * jnp.pi * f / vel  # shape (Ny,Nx)

    # 3) build PML stretching functions
    xe = jnp.linspace(x.min(), x.max(), 2*(Nx-1) + 1)
    ye = jnp.linspace(y.min(), y.max(), 2*(Ny-1) + 1)
    Xe, Ye = jnp.meshgrid(xe, ye, indexing='xy')

    xctr, xspan = 0.5*(x.min()+x.max()), 0.5*(x.max()-x.min())
    yctr, yspan = 0.5*(y.min()+y.max()), 0.5*(y.max()-y.min())

    # clamp-to-zero then square
    def clamp(Z, span):
        return jnp.maximum(jnp.abs(Z - span) + L_PML - span, 0) / L_PML

    sx = 2*jnp.pi*a0*f * (clamp(Xe, xspan)**2)
    sy = 2*jnp.pi*a0*f * (clamp(Ye, yspan)**2)

    signConvention = -1
    ex = 1 + 1j * sx * jnp.sign(signConvention) / (2*jnp.pi*f)
    ey = 1 + 1j * sy * jnp.sign(signConvention) / (2*jnp.pi*f)

    # resize back to (Ny,Nx) by subsampling every-other point
    A =       (ey/ex)[0::2, 1::2]  # shape (Ny,Nx)
    B =       (ex/ey)[1::2, 0::2]
    C = (ex*ey)[0::2, 0::2]

    # 4) optimal stencil constants
    b, d, e = stencilOptParams(jnp.min(vel), jnp.max(vel), f, h, g)

    # 5) flatten‐index helper
    flat = lambda xi, yi: yi * Nx + xi

    # 6) build full grid of indices
    Xg, Yg  = jnp.meshgrid(jnp.arange(Nx), jnp.arange(Ny), indexing='xy')
    rows_all = Yg * Nx + Xg          # flattened row idx

    # boundary mask
    is_bnd = (Xg == 0) | (Xg == Nx-1) | (Yg == 0) | (Yg == Ny-1)
    is_int = ~is_bnd

    # boundary entries: diag=1
    bnd_idx = rows_all[is_bnd]
    bnd_rows = bnd_idx
    bnd_cols = bnd_idx
    bnd_vals = jnp.ones_like(bnd_idx, dtype=jnp.float32)

    # interior points
    Xi, Yi = Xg[is_int], Yg[is_int]
    rowi   = rows_all[is_int]       # shape (M,)
    M      = rowi.size

    inv_h2 = 1.0/(h**2)
    g2     = g**2

    # gather all needed neighbors at once
    Ci   = C[Yi, Xi];   ki2  = (k[Yi,   Xi]**2).real
    Ai   = A[Yi, Xi];   Bi   = B[Yi,   Xi]

    Ai_L = A[Yi, Xi-1];   Bi_L = B[Yi, Xi-1]
    Ai_R = A[Yi, Xi+1];   Bi_R = B[Yi, Xi+1]
    Bi_D = B[Yi-1, Xi];   Ai_DL= A[Yi-1, Xi-1];   Bi_DR = B[Yi-1, Xi+1]
    Ai_UR= A[Yi+1, Xi];   Ai_UL= A[Yi+1, Xi-1]
    Ci_L = C[Yi, Xi-1];   Ci_R = C[Yi, Xi+1]
    Ci_D = C[Yi-1, Xi];   Ci_U = C[Yi+1, Xi]
    Ci_DL= C[Yi-1,Xi-1];  Ci_DR= C[Yi-1,Xi+1]
    Ci_UL= C[Yi+1,Xi-1];  Ci_UR= C[Yi+1,Xi+1]
    
    Bi_DL = B[Yi-1, Xi-1]   # bottom‐left neighbor
    Bi_UL = B[Yi+1, Xi-1]   # top‐left neighbor
    Bi_UR = B[Yi+1, Xi+1]   # top‐right neighbor


    k2_L  = (k[Yi,   Xi-1]**2).real
    k2_R  = (k[Yi,   Xi+1]**2).real
    k2_D  = (k[Yi-1, Xi]**2).real
    k2_U  = (k[Yi+1, Xi]**2).real
    k2_DL = (k[Yi-1,Xi-1]**2).real
    k2_DR = (k[Yi-1,Xi+1]**2).real
    k2_UL = (k[Yi+1,Xi-1]**2).real
    k2_UR = (k[Yi+1,Xi+1]**2).real

    # compute 9 stencil weights
    vc = ((1-d-e)*Ci*ki2 - b*(Ai+Ai_L+Bi/g2+Bi_D/g2)*inv_h2).astype(jnp.float32)
    vl = (( b*Ai_L - ((1-b)/2)*(Bi_L/g2+Bi_DL/g2))*inv_h2 + (d/4)*Ci_L*k2_L)
    vr = (( b*Ai   - ((1-b)/2)*(Bi_R/g2+Bi_DR/g2))*inv_h2 + (d/4)*Ci_R*k2_R)
    vd = (( b*Bi_D/g2 - ((1-b)/2)*(Ai_DL + A[Yi-1,Xi]))*inv_h2 + (d/4)*Ci_D*k2_D)
    vu = (( b*Bi   /g2 - ((1-b)/2)*(Ai_UR+Ai_UL))*inv_h2 + (d/4)*Ci_U*k2_U)
    vbl=(((1-b)/2)*(Ai_DL+Bi_DL/g2)*inv_h2 + (e/4)*Ci_DL*k2_DL)
    vbr=(((1-b)/2)*(Ai   +Bi_DR/g2)*inv_h2 + (e/4)*Ci_DR*k2_DR)
    vtl=(((1-b)/2)*(Ai_UL+Bi_UL/g2)*inv_h2 + (e/4)*Ci_UL*k2_UL)
    vtr=(((1-b)/2)*(Ai_UR+Bi_UR/g2)*inv_h2 + (e/4)*Ci_UR*k2_UR)

    # flattened column indices
    cc = rowi
    cl = flat(Xi-1, Yi)
    cr = flat(Xi+1, Yi)
    cd = flat(Xi,   Yi-1)
    cu = flat(Xi,   Yi+1)
    cbl= flat(Xi-1, Yi-1)
    cbr= flat(Xi+1, Yi-1)
    ctl= flat(Xi-1, Yi+1)
    ctr= flat(Xi+1, Yi+1)

    # stack interior COO entries
    rows_int = jnp.repeat(rowi, 9)
    cols_int = jnp.concatenate([cc, cl, cr, cd, cu, cbl, cbr, ctl, ctr])
    vals_int = jnp.concatenate([vc, vl, vr, vd, vu, vbl, vbr, vtl, vtr])

    # merge boundary + interior
    r_all = jnp.concatenate([bnd_rows, rows_int])
    c_all = jnp.concatenate([bnd_cols, cols_int])
    v_all = jnp.concatenate([bnd_vals, vals_int])

    # assemble COO → BCOO
    H_coo = BCOO((v_all, jnp.stack([r_all, c_all], axis=1)),
                 shape=(Ntot, Ntot))
    print(H_coo)
    
    # 1) Elige la matriz correcta: H o H^*
    if adjoint:
        H_t = H_coo.transpose()  # sigue siendo BCOO
        H_use = BCOO((jnp.conj(H_t.data), H_t.indices), shape=H_t.shape)
    else:
        H_use = H_coo

    # 2) Ahora convertimos UNA SOLA VEZ ese BCOO (ya transpuesto/conjugado si tocaba)
    H_use = BCSR.from_bcoo(H_use)

    # reshape src to Nx*Ny and -1 to get the right shape
    rhs = jnp.reshape(src, (Nx * Ny, -1))
    # print("rhs shape", rhs.shape)
    # Change type to complex64
    rhs = jnp.array(rhs, dtype=jnp.complex64)

    data, indices, indptr = H_use.data, H_use.indices, H_use.indptr
    # print(H_use.indices)
    # print("non empty values", H_use.data[H_use.data != 0].sum())
    # print("Solving system...")
    # sol = spsolve(data, indices, indptr, rhs)
    # tic toc for time analysis

    # start = time.time()
    print(data)
    print(indices)
    print(indptr)
    sol = jnp.stack(
        [spsolve(data, indices, indptr, rhs[:, i]) for i in range(rhs.shape[1])], axis=1
    )
    ##end = time.time()
    # print("Time taken to solve system:", end - start)

    return sol.reshape(Ny, Nx, -1)

In [None]:
def stencilOptParams(vmin,vmax,f,h,g):
#STENCILOPTPARAMS Optimal Params for 9-Point Stencil 
#   INPUTS:
#       vmin = minimum wave velocity [L/T]
#       vmax = maximum wave velocity [L/T]
#       f = frequency [1/T]
#       h = grid spacing in X [L]
#       g = (grid spacing in Y [L])/(grid spacing in X [L])
#   OUTPUTS:
#       b, d, e = optimal params according to Chen/Cheng/Feng/Wu 2013 Paper

    l = 100
    r = 10

    Gmin = vmin / (f * h)
    Gmax = vmax / (f * h)

    # MATLAB m = 1:l → Python 1,2,…,l
    m = jnp.arange(1, l+1)
    # MATLAB n = 1:r → Python 1,2,…,r
    n = jnp.arange(1, r+1)

    # theta = (m-1)*pi/(4*(l-1))
    theta = (m - 1) * jnp.pi / (4 * (l - 1))

    # G = 1./(1/Gmax + ((n-1)/(r-1))*(1/Gmin-1/Gmax));
    G = 1.0 / (
        1.0 / Gmax +
        ((n - 1) / (r - 1)) * (1.0 / Gmin - 1.0 / Gmax)
    )

    # replicated exactly as [TH,GG]=meshgrid(theta,G)
    TH, GG = jnp.meshgrid(theta, G, indexing='xy')

    # the four stencil-estimator summands
    P = jnp.cos(g * 2 * jnp.pi * jnp.cos(TH) / GG)
    Q = jnp.cos(2 * jnp.pi * jnp.sin(TH) / GG)

    S1 = (1 + 1/(g**2)) * (GG**2) * (1 - P - Q + P*Q)
    S2 = (jnp.pi**2)    * (2 - P - Q)
    S3 = (2*jnp.pi**2)  * (1 - P*Q)
    S4 = 2*jnp.pi**2 + (GG**2) * (
         (1 + 1/(g**2))*P*Q - P - Q/(g**2)
    )

    fixB = True
    if fixB:
        b = 5/6
        A = jnp.stack([S2.ravel(), S3.ravel()], axis=1)     # (M,2)
        y = S4.ravel() - b * S1.ravel()                     # (M,)
        # solve (AᵀA)·params = Aᵀy
        ATA    = A.T @ A
        ATy    = A.T @ y
        params = jnp.linalg.solve(ATA, ATy)                 # (2,)
        d, e = params[0], params[1]
    else:
        A = jnp.stack([S1.ravel(), S2.ravel(), S3.ravel()], axis=1)  # (M,3)
        y = S4.ravel()
        ATA    = A.T @ A
        ATy    = A.T @ y
        params = jnp.linalg.solve(ATA, ATy)                          # (3,)
        b, d, e = params[0], params[1], params[2]

    return b, d, e