In [1]:
# reload modules automatically before running code; mode 1 reloads only the
# modules registered with %aimport (mre_pinn, in the next cell)
%load_ext autoreload
%autoreload 1
# interactive figures via the classic Jupyter Notebook backend
%matplotlib notebook

In [3]:
import sys, os
import numpy as np
import matplotlib.pyplot as plt

# FENICSx imports
import ufl
import dolfinx
from mpi4py import MPI
from petsc4py.PETSc import ScalarType as dtype

sys.path.append('..')
%aimport mre_pinn

# the inversion below works with complex-valued fields (data.mu is complex),
# which requires PETSc built with complex scalars
assert np.issubdtype(dtype, np.complexfloating), (
    f'PETSc ScalarType is {dtype}; a complex PETSc build is required'
)
dtype

numpy.complex128

In [100]:
from scipy.interpolate import LinearNDInterpolator

# define a vector-valued function f: R^2 \to R^2
#   NOTE that in reality, I do not have an explicit function like this
def f(x):
    '''Map coordinate rows (x[0], x[1]) to the two components
    ((x[0] + x[1])**2, (x[0] - x[1])**2), stacked along axis 0.'''
    first, second = x[0], x[1]
    return np.stack([(first + second)**2, (first - second)**2], axis=0)

def sample_domain(n):
    '''Return the n*n points of a uniform grid on [-1, 1]^2 as an (n*n, 2) array.

    Points are ordered with the first coordinate varying fastest.
    '''
    grid = np.linspace(-1, 1, n)
    first, second = np.meshgrid(grid, grid)
    return np.column_stack((first.ravel(), second.ravel()))

# evaluate f on a set of domain samples
#   x: (n*n, 2) grid points in [-1, 1]^2; y: (n*n, 2) values of f at those points
n = 100
x = sample_domain(n)
y = f(x.T).T

# create an FEM mesh and basis
#   n_nodes vertices per side -> (n_nodes - 1) x (n_nodes - 1) rectangles,
#   triangulated
n_nodes = 5 # per dimension
mesh = dolfinx.mesh.create_rectangle(
    comm=MPI.COMM_WORLD,
    points=[[-1, -1], [1, 1]],
    n=[n_nodes - 1, n_nodes - 1],
    cell_type=dolfinx.mesh.CellType.triangle,
    diagonal=dolfinx.mesh.DiagonalType.right_left
)
# continuous piecewise-linear (P1) vector space with 2 components
V = dolfinx.fem.VectorFunctionSpace(mesh, ('Lagrange', 1), dim=2)

# interpolate f into the FEM basis
#   f reads coordinates as rows (x[0], x[1]), the same layout used by f(x.T) above
f_h = dolfinx.fem.Function(V)
f_h.interpolate(f)

def eval_func(f_h, x):
    '''Evaluate a dolfinx function on a set of points.

    Args:
        f_h: dolfinx.fem.Function to evaluate. The cell search uses the mesh
            of f_h's own function space, so the result does not depend on
            which `mesh` is currently defined in the notebook.
        x: (N, d) array of point coordinates, one row per point, d <= 3.
            Missing trailing coordinates (e.g. z for 2D points) are zero-padded
            because the dolfinx geometry queries take 3D points.

    Returns:
        The output of f_h.eval at the N points (one row per point).

    Raises:
        ValueError: if a point does not lie in any cell of the mesh.
    '''
    x = np.asarray(x)
    n_points, n_coords = x.shape
    if n_coords < 3:
        x = np.concatenate([x, np.zeros((n_points, 3 - n_coords))], axis=1)
    x = np.ascontiguousarray(x, dtype=np.float64)

    f_mesh = f_h.function_space.mesh
    # the tree is built over cells, i.e. entities of the topological dimension
    tree = dolfinx.geometry.BoundingBoxTree(f_mesh, f_mesh.topology.dim)
    candidates = dolfinx.geometry.compute_collisions(tree, x)
    colliding = dolfinx.geometry.compute_colliding_cells(f_mesh, candidates, x)

    cells = []
    for i in range(n_points):
        links = colliding.links(i)
        if len(links) == 0:
            raise ValueError(f'point {i} at {x[i]} is not inside the mesh')
        # a point on a shared facet/vertex collides with several cells;
        # the first colliding cell is used
        cells.append(links[0])
    return f_h.eval(x, cells)

# evaluate the FEM function on the domain
y_h = eval_func(f_h, x)

# create data set of samples from f
#   NOTE this is what I have in reality that I want to represent in the FEM basis
n_data = 5
x_data = sample_domain(n_data)
y_data = f(x_data.T).T

# create function that interpolates data samples
#   NOTE(review): y_data is reassigned here: before this line it holds the
#   (n_data*n_data, 2) raw samples, after it the interpolant evaluated on x
f_data = LinearNDInterpolator(x_data, y_data)
y_data = f_data(x)

# interpolate f_data into the FEM basis
#   the lambda receives coordinates as rows, so x[:2].T gives one point per row
#   for f_data; the transposed result is made C-contiguous, presumably because
#   interpolate does not accept the non-contiguous transpose — confirm
f_data_h = dolfinx.fem.Function(V)
f_data_h.interpolate(lambda x: np.ascontiguousarray(f_data(x[:2].T).T))
y_data_h = eval_func(f_data_h, x)

# reshape into images for plotting
#   sample_domain orders points with the first coordinate varying fastest, so
#   image[i, j] holds the point whose first coordinate is the j-th grid value
#   and whose second coordinate is the i-th: rows follow the second coordinate
y = y.reshape(n, n, 2)
y_h = y_h.reshape(n, n, 2)
y_data = y_data.reshape(n, n, 2)
y_data_h = y_data_h.reshape(n, n, 2)

# plot the components of y, y_h, y_data, and y_data_h
#   origin='lower' puts row 0 (second coordinate = -1) at the bottom, matching
#   `extent`; the default origin='upper' would flip each image vertically, which
#   for this f visually swaps the two components
fig, axes = plt.subplots(2, 4, figsize=(8, 4))
panels = [('y', y), ('y_h', y_h), ('y_data', y_data), ('y_data_h', y_data_h)]
for col, (title, image) in enumerate(panels):
    for comp in range(2):
        axes[comp, col].imshow(
            image[..., comp].real, origin='lower',
            cmap='Greys', vmin=0, vmax=4, extent=[-1, 1, -1, 1]
        )
    axes[0, col].set_title(title)

axes[0,0].set_ylabel('component 0')
axes[1,0].set_ylabel('component 1')

fig.tight_layout()

<IPython.core.display.Javascript object>

In [49]:
# spot-check the data interpolant at the points (1, 0) and (0, 1) (the rows of
# np.eye(2)); each output row holds the 2 components of f_data at one point.
# Both points are on the 5x5 sample grid, so the exact value f = [1, 1] is expected.
#   NOTE(review): this cell's execution count (49) predates the cell above (100),
#   so the displayed output may come from an earlier kernel state — re-run to refresh
f_data(np.eye(2))

array([[1.00010203, 1.00010203],
       [1.00010203, 1.00010203]])

## 2D Helmholtz inverse FEM

We want to solve for the elasticity field $\mu: \Omega \to \mathbb{C}$ given the wave field $u: \Omega \to \mathbb{C}^2$, where $\Omega \subset \mathbb{R}^2$.

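Strong form, in the Helmholtz approximation that neglects spatial derivatives of $\mu$ (i.e. $\mu$ is treated as locally homogeneous). Here $\rho$ is the mass density and $\omega = 2\pi f$ is the angular frequency of the time-harmonic wave at frequency $f$; the equation holds for each component of $u$: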

$$
\begin{align}
    \mu(\mathbf{x}) \nabla^2 u(\mathbf{x}) + \rho \omega^2 u(\mathbf{x}) &= 0
\end{align}
$$

In [62]:
%autoreload
# (bare %autoreload: reload modules now, picking up edits to mre_pinn)

# load the data set
#   frequency=80 and xyz_slice='2D' select a single-frequency 2D slice (the
#   printout reports "Single frequency 2D"); the second return value is unused
data, _ = mre_pinn.data.load_bioqic_dataset(
    '../data/BIOQIC', 'fem_box', frequency=80, xyz_slice='2D', downsample=False
)
data

Loading ../data/BIOQIC/four_target_phantom.mat
    __header__: <class 'bytes'>
    __version__: <class 'str'>
    __globals__: <class 'list'>
    u_ft: <class 'numpy.ndarray'> (100, 80, 10, 3, 6) complex128
Loading ../data/BIOQIC/fem_box_elastogram.npy
     <class 'numpy.ndarray'> (6, 10, 80, 100) complex128
Loading ../data/BIOQIC/fem_box_regions.npy
     <class 'numpy.ndarray'> (10, 80, 100) int64
Single frequency 2D
<xarray.Dataset>
Dimensions:         (frequency: 1, x: 80, y: 100, component: 2)
Coordinates:
  * frequency       (frequency) float64 80.0
  * x               (x) float64 0.0 0.001 0.002 0.003 ... 0.077 0.078 0.079
  * y               (y) float64 0.0 0.001 0.002 0.003 ... 0.097 0.098 0.099
    z               float64 0.0
  * component       (component) <U1 'z' 'y'
Data variables:
    u               (frequency, x, y, component) complex128 (-4.2190458277627...
    mu              (frequency, x, y) complex128 (3000+502.6548245743669j) .....
    spatial_region  (x, y) int64 

In [63]:
# colour keyword arguments for the wave field (also reused by imshow further down)
wave_kws = mre_pinn.visual.get_color_kws(data.u)
mre_pinn.visual.XArrayViewer(data.u, col='part', ax_width=2, **wave_kws);

<IPython.core.display.Javascript object>

interactive(children=(SelectionSlider(description='domain', options=(('space', 0), ('frequency', 1)), value=0)…

<mre_pinn.visual.XArrayViewer at 0x14e91b57f280>

In [83]:
n_frequency, n_x, n_y, n_component = data.u.shape

# data point coordinates and the first wave-field component at each point
x = data.u.field.spatial_points()
u = data.u.field.values()[:,0:1]
print(x.shape, u.shape)

x_min = x.min(axis=0)
x_max = x.max(axis=0)
print(x_min)
print(x_max)

# piecewise-linear interpolant of the samples, evaluated back at the samples
interp = LinearNDInterpolator(points=x, values=u)
u_interp = interp(x)
print(u_interp.shape)

# P1 mesh with n_x by n_y vertices spanning the data's bounding box
mesh = dolfinx.mesh.create_rectangle(
    comm=MPI.COMM_WORLD,
    points=[x_min, x_max],
    n=[n_x - 1, n_y - 1],
    cell_type=dolfinx.mesh.CellType.triangle,
    diagonal=dolfinx.mesh.DiagonalType.right_left
)
V = dolfinx.fem.FunctionSpace(mesh, ('Lagrange', 1))
func = dolfinx.fem.Function(V)
# the lambda receives coordinates as rows; interp expects one point per row
func.interpolate(lambda x: interp(x[:2].T).T)

# evaluate the FEM function back at the data points (x is not modified above)
u_func = eval_func(func, x)
print(u_func.shape)

# reshape into (n_x, n_y) images for plotting, following the loaded data's grid size
u = u.reshape(n_x, n_y)
u_interp = u_interp.reshape(n_x, n_y)
u_func = u_func.reshape(n_x, n_y)

fig, axes = plt.subplots(1, 3, figsize=(8, 3))
for ax, (title, image) in zip(axes, [('u', u), ('u_interp', u_interp), ('u_func', u_func)]):
    # transpose so that x runs horizontally and y vertically
    ax.imshow(image.T.real, origin='lower', **wave_kws)
    ax.set_title(title)

(8000, 2) (8000, 1)
[0. 0.]
[0.079 0.099]
(8000, 1)
(8000, 1)


<IPython.core.display.Javascript object>

Text(0.5, 1.0, 'u_func')

In [86]:
# get the domain from the wave field data

n_frequency, n_x, n_y, n_component = data.u.shape
assert n_frequency == 1
# (no z check: the data was loaded with xyz_slice='2D', so data.u has no z axis)

# x: (N, 2) data point coordinates; u: (N, 1), first wave-field component only
x = data.u.field.spatial_points()
u = data.u.field.values()[:,:1]
print(x.shape, u.shape)

# corners of the data's bounding box (first two coordinates)
x_min = x.min(axis=0)[:2]
x_max = x.max(axis=0)[:2]
print(x_min)
print(x_max)

# define the mesh and function spaces
#   (n_x - 1) x (n_y - 1) cells over the bounding box, so the vertices presumably
#   coincide with the data points if they form a regular n_x by n_y grid

mesh = dolfinx.mesh.create_rectangle(
    comm=MPI.COMM_WORLD,
    points=[x_min, x_max],
    n=[n_x - 1, n_y - 1],
    cell_type=dolfinx.mesh.CellType.triangle
)

# continuous piecewise-linear (P1) scalar space, used for both u and mu below
scalar_func_space = dolfinx.fem.FunctionSpace(mesh, ('Lagrange', 1))
# vector-valued alternative, not used below (only one component of u is kept)
#vector_func_space = dolfinx.fem.VectorFunctionSpace(mesh, ('Lagrange', 1), dim=2)

(8000, 2) (8000, 1)
[0. 0.]
[0.079 0.099]


In [87]:
# identify boundary nodes

def on_boundary(x, x_min=None, x_max=None):
    '''Flag points lying on the edge of an axis-aligned 2D rectangle.

    Args:
        x: (d, N) array of coordinates with one column per point and d >= 2,
            e.g. the coordinate array dolfinx.fem.locate_dofs_geometrical
            passes to its marker. Only the first two rows are used (2D domain).
        x_min, x_max: optional lower/upper corners of the rectangle (at least
            2 entries each). When omitted, the bounding box of `x` is used.
            NOTE(review): under MPI a geometrical marker may only see the
            process-local points, whose bounding box is not the global domain;
            pass explicit bounds in that case.

    Returns:
        (N,) boolean array, True for points on the rectangle's boundary.
    '''
    x = np.asarray(x)[:2] # assume only 2D
    if x.shape[1] == 0:
        return np.zeros(0, dtype=bool)
    lower = x.min(axis=1) if x_min is None else np.asarray(x_min, dtype=float)[:2]
    upper = x.max(axis=1) if x_max is None else np.asarray(x_max, dtype=float)[:2]
    # centre on the bounding-box midpoint, not the mean of the points (the mean
    # is off-centre whenever the points are not distributed symmetrically)
    center = ((lower + upper) / 2)[:, None]
    half_width = ((upper - lower) / 2)[:, None]
    # a zero-width axis would divide by zero; leave that axis unscaled
    half_width = np.where(half_width > 0, half_width, 1.0)
    x_scaled = (x - center) / half_width
    # boundary points have |scaled coordinate| == 1 along at least one axis
    return np.isclose(np.abs(x_scaled).max(axis=0), 1)

# visual check of the boundary marker on the data points:
#   boundary points are drawn darker (0.75) than interior points (0.25) in 'Greys'
is_boundary = on_boundary(x.T)
colors = is_boundary.astype(float) / 2 + 0.25

fig, ax = plt.subplots(figsize=(4, 5))
ax.scatter(*x.T[:2], c=colors, s=2, cmap='Greys', vmin=0, vmax=1)
ax.set_aspect('equal')
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_title('boundary points');

<IPython.core.display.Javascript object>

<matplotlib.collections.PathCollection at 0x14e91a998c40>

In [89]:
# setup physical problem and boundary conditions

rho = 1000  # density; presumably kg/m^3 (SI, like the metre-scale coordinates) — TODO confirm
# derive the angular frequency from the loaded data rather than repeating 80 Hz;
# the domain cell asserts there is exactly one frequency
omega = 2 * np.pi * data.frequency.item()

# interpolate the measured wave field (first component) into the P1 space;
# the lambda receives coordinates as rows, u_interp expects one point per row
u_interp = LinearNDInterpolator(points=x, values=u)
u_func = dolfinx.fem.Function(scalar_func_space)
u_func.interpolate(lambda x: u_interp(x[:2].T).T)

# uniform background value for mu, used as the Dirichlet value on the boundary
#   NOTE(review): data.mu is complex (first entry 3000+502.65j) while this value
#   is purely real — confirm that is intended
mu_background = 3e3
mu_func = dolfinx.fem.Function(scalar_func_space)
mu_func.interpolate(lambda x: mu_background * np.ones_like(x[0]))

boundary_dofs = dolfinx.fem.locate_dofs_geometrical(scalar_func_space, on_boundary)
mu_bc = dolfinx.fem.dirichletbc(mu_func, dofs=boundary_dofs)

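Variational form, obtained by multiplying the strong form by the complex conjugate $\bar{v}$ of a test function $v$, integrating over $\Omega$, and integrating the Laplacian term by parts (again neglecting derivatives of $\mu$ and dropping the boundary term). Since $u$ is known from the data, $\mu$ is the unknown (trial) function: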

$$
\begin{align}
    \int_\Omega \mu \nabla u \cdot \nabla \bar{v} d\Omega &= \int_\Omega \rho \omega^2 u \bar{v} d\Omega
\end{align}
$$

In [96]:
%%time

# solve variational Helmholtz problem
trial_func = ufl.TrialFunction(scalar_func_space)
test_func = ufl.TestFunction(scalar_func_space)

f_func = dolfinx.fem.Function(scalar_func_space) # source is zero
f_func.interpolate(lambda x: 0 * x[:1])

Ax = trial_func * ufl.inner(ufl.grad(u_func), ufl.grad(test_func)) * ufl.dx
b = rho * omega**2 * ufl.inner(u_func, test_func) * ufl.dx

problem = dolfinx.fem.petsc.LinearProblem(
    Ax, b, bcs=[], #petsc_options={"ksp_type": "lsqr", "pc_type": "none"}
)
mu_pred_func = problem.solve()

CPU times: user 214 ms, sys: 16.4 ms, total: 230 ms
Wall time: 309 ms


In [97]:
# evaluate function on mesh
#   sample the predicted mu at the data points and reshape to data.mu's shape
#   (reported as (1, 80, 100) in the output) so it can sit alongside data.mu

x = data.u.field.spatial_points()
mu_pred = eval_func(mu_pred_func, x)
mu_pred = mu_pred.reshape(*data.mu.shape)
mu_pred.shape, mu_pred.dtype

((1, 80, 100), dtype('complex128'))

In [98]:
# visualize the elastogram

# attach the prediction to the dataset as 'mu_pred' (overwritten on re-run);
# as_xarray(..., like=data.mu) presumably copies data.mu's dims/coords — confirm
data['mu_pred'] = mre_pinn.utils.as_xarray(mu_pred, like=data.mu)

# colour keyword arguments derived from the prediction's own values
#   NOTE(review): data.mu is not shown here; plotting it with the same kws would
#   let the prediction be compared directly against the reference elastogram
elast_kws = mre_pinn.visual.get_color_kws(data.mu_pred)

mre_pinn.visual.XArrayViewer(data.mu_pred, ax_width=2, **elast_kws)

<IPython.core.display.Javascript object>

interactive(children=(SelectionSlider(description='part', options=(('real', 0), ('imag', 1)), value=0), Select…

<mre_pinn.visual.XArrayViewer at 0x14e91ae00ac0>