In [1]:
# Simple, corrected 2D cross-section ultrasound sim
# deps: pip install nibabel jwave scipy

import numpy as np
import nibabel as nib
import matplotlib.pyplot as plt
import jax.numpy as jnp
from jax import jit
from scipy.signal import hilbert

from jwave.geometry import Domain, Medium, TimeAxis, Sensors, Sources
from jwave.acoustics import simulate_wave_propagation

In [11]:
methods = {#"NN0p5mm":         "reviewUpsampledSegments/C4NNUpsampleddx5.nii.gz", 
           #"Linear0p5mm":     "reviewUpsampledSegments/C4LinearUpsampleddx5.nii.gz", 
           "Mesh0p5mm":       "reviewUpsampledSegments/C4MeshUpsampleddx5.nii.gz"}
           #"NN0p2mm":         "reviewUpsampledSegments/C4NNUpsampleddx2.nii.gz", 
           #"Linear0p2mm":     "reviewUpsampledSegments/C4LinearUpsampleddx2.nii.gz", 
           #"Mesh0p2mm":       "reviewUpsampledSegments/C4MeshUpsampleddx2.nii.gz",
           #"GroundTruth0p2mm":"UpsampledSegments/newC4HalfGroundTruth.nii.gz"}
# medium setting
c0, rho0 = 1500.0, 1000.0# m/s, kg/m^3
c_bone,  rho_bone = 2200.0, 2600.0
voxel_size = 1e-4
SLICE_AXIS   = "x"# "z", "y", or "x"

# source & numerics
fc         = 1.0e6# Hz (will be capped by grid spacing)
cfl        = 0.3

FOCUS_DEPTH_MM = 60 # depth from the source edge into the slice
n_burst_onepulse = 3# how many bursts for short pulse simulation
dz, dy, dx = voxel_size, voxel_size, voxel_size

In [3]:
def loadimg(m, grid):
    NIFTI_PATH   = methods[m + grid]  # 3D mask (Z,Y,X); vertebra ~1
    #print(NIFTI_PATH)
    nii = nib.load(NIFTI_PATH)
    vol = nii.get_fdata(dtype=np.float32)              # (Z,Y,X)
    return vol

def setmedium(m, grid, SLICE_INDEX):
    global Ny2, Nx2, ys, xs

    def pick_idx(axis_name, idx):
        n = {"z": Nz, "y": Ny, "x": Nx}[axis_name.lower()]
        return n//2 if idx == "mid" else int(idx)

    
    vol = loadimg(m, grid)
    Nz, Ny, Nx = vol.shape
    ax  = SLICE_AXIS.lower()
    idx = pick_idx(ax, SLICE_INDEX)
    if ax == "z":
        img2d = np.ascontiguousarray(vol[idx, :, :].astype(np.float32))   # (Y,X)
        dy2, dx2 = dy, dx
    elif ax == "y":
        img2d = np.ascontiguousarray(vol[:, idx, :].astype(np.float32))   # (Z,X)
        dy2, dx2 = dz, dx
    else:  # "x"
        img2d = np.ascontiguousarray(vol[:, :, idx].astype(np.float32))   # (Z,Y)
        dy2, dx2 = dz, dy
    
    Ny2, Nx2 = img2d.shape
    mask2d = np.int32(img2d)
    ROI = mask2d
    ys, xs = np.where(ROI>=0)
    Ns = ys.size
    
    #sensors = Sensors(positions=(jnp.asarray(ys), jnp.asarray(tuple(xs))))
    sensors_pack = (jnp.asarray(ys, dtype=jnp.int32),
                jnp.asarray(xs, dtype=jnp.int32))
    c_field = mask2d 
    c_field = mask2d * (c_bone - c0) + c0
    rho_field = mask2d
    rho_field = mask2d * (rho_bone - rho0) + rho0
    
    dom2d   = Domain((Ny2, Nx2), (dy2, dx2))    
    c_field_reshaped   = jnp.asarray(c_field,   dtype=jnp.float32)[..., None]   # (Ny, Nx, 1)
    rho_field_reshaped = jnp.asarray(rho_field, dtype=jnp.float32)[..., None]   # (Ny, Nx, 1)
    medium2 = Medium(domain=dom2d, sound_speed=c_field_reshaped, density=rho_field_reshaped)

    return sensors_pack, dom2d, medium2

In [4]:
def setsource(medium2, dom2d, cfl):
    global Nt, dt
    def plane_source():
        signals_np = np.zeros((Ne2, Nt), np.float32)
        for k in range(n_burst):
            start = k * len_burst
            stop  = start + len_burst
            stop = min(stop, Nt)
        
            if start >= Nt:         # outside time window, skip
                break
            seg  = np.asarray(s0[:(stop - start)], dtype=np.float32)
            signals_np[:, start:stop] += seg  # same burst on all elements (plane wave)
        return signals_np
    
    def focused_source():
        signals_np = np.zeros((Ne2, Nt), np.float32)  # (elements, time)
        
        # focus pixel (2D)
        focus_y_pix = int(round((FOCUS_DEPTH_MM*1e-3) / dy))
        focus_x_pix = Nx2 * 2 // 5
        
        yy = src_y * dy
        xx = src_x * dx
        fy = float(focus_y_pix) * dy
        fx = float(focus_x_pix) * dx
        
        delay2d   = np.sqrt((yy - fy)**2 + (xx - fx)**2)
        tof   = delay2d / c0
        tof = tof.max() - tof
        sd = np.round(tof / dt).astype(int)
        
        for i, delay_samp in enumerate(sd):           # one delay per element
            if delay_samp >= Nt:
                continue
            for k in range(n_burst):
                start = delay_samp + k * len_burst
                if start >= Nt:
                    break
                stop = min(start + len_burst, Nt)
                if stop > start:
                    seg = s0[:(stop - start)]
                    signals_np[i, start:stop] += seg
        
        peak = np.max(np.abs(signals_np))
        if peak > 0:
            signals_np /= peak
        return signals_np

    src_y_offset = 20+3            # put source a few cells inside the boundary
    src_x  = np.arange(Nx2, dtype=int)
    src_y = np.full_like(src_x, src_y_offset, dtype=int)
    Ne2   = int(src_x.size)
    
    # Time axis and pulse-train parameters
    depth_time = (Ny2 * dy) / c0
    burst_time = (1 / fc)
    
    if simulationtype == "onepulse": # onepulse, continuous
        n_burst = n_burst_onepulse
        t_end = depth_time
    
    elif simulationtype == "continuous":
        n_burst = int(depth_time*1.5/burst_time) + 1
        t_end = depth_time * 1.5
        
    t_axis = TimeAxis.from_medium(medium2, cfl=cfl, t_end=t_end)
    Nt = int(t_axis.Nt); dt = float(t_axis.dt)
    t  = jnp.arange(Nt) * dt
    
    tau   = jnp.arange(int(jnp.ceil(burst_time/dt))) * dt
    s0    = jnp.sin(2*jnp.pi*fc*tau)
    s0 = np.asarray(s0, np.float32)
    
    len_burst  = int(s0.size)

    if sourcetype == "planar":
        signals_np = plane_source()
    elif sourcetype == "focused":
        signals_np = focused_source()
    
    signals_2d = jnp.asarray(signals_np)
    
    srcs = Sources(
        positions=(jnp.asarray(src_y.tolist()), jnp.asarray(src_x.tolist())),
        signals=signals_2d,
        dt=dt,
        domain=dom2d
    )
    return t_axis, srcs

In [5]:
#plot the signal (if want)
"""
# pick one element (e.g. the first transducer element)
sig0 = signals_np[0]

time_us = np.arange(sig0.size) * dt * 1e6

plt.figure(figsize=(10,4))
plt.plot(time_us, sig0, lw=1)
plt.xlabel("Time [µs]")
plt.ylabel("Amplitude [a.u.]")
plt.title(m + grid + sourcetype + simulationtype)
plt.grid(True)
plt.show()
"""

'\n# pick one element (e.g. the first transducer element)\nsig0 = signals_np[0]\n\ntime_us = np.arange(sig0.size) * dt * 1e6\n\nplt.figure(figsize=(10,4))\nplt.plot(time_us, sig0, lw=1)\nplt.xlabel("Time [µs]")\nplt.ylabel("Amplitude [a.u.]")\nplt.title(m + grid + sourcetype + simulationtype)\nplt.grid(True)\nplt.show()\n'

In [6]:
def run(medium2, t_axis, srcs, sensors_pack):
    sensors = Sensors(positions=(sensors_pack[0], sensors_pack[1]))
    p_wave = simulate_wave_propagation(medium2, t_axis, sources=srcs, sensors=sensors)
    #print(p_wave.shape)
    return p_wave

In [7]:
def to_image(vec, fill=np.nan):
    img = np.full((Ny2, Nx2), fill, dtype=np.float32)
    img[ys, xs] = vec.astype(np.float32)
    return img

In [8]:
def loadsliceROI(s):
    ROI_path = "ROIs/erode_C4ROI3D.npy"
    ROI3D = np.load(ROI_path) #canal shape
    
    return ROI3D[:, :, s]

In [9]:
def savecrossection(pathbase):
    #save cross section at 25%, 50%, 75% and end time axis
    simlength, _ = A.shape
    
    frame_img = to_image(A[int(simlength/4)], fill=np.nan)
    np.save(pathbase + "_t25.npy", frame_img)
    
    frame_img = to_image(A[int(simlength/4*2)], fill=np.nan)
    np.save(pathbase + "_t50.npy", frame_img)
    
    frame_img = to_image(A[int(simlength/4*3)], fill=np.nan)
    np.save(pathbase + "_t75.npy", frame_img)
    
    frame_img = to_image(A[-1], fill=np.nan)
    np.save(pathbase + "_t100.npy", frame_img)

def savepeakpressure(pathbase):
    #save peak pressure field from short pulse simulation
    max_pressure = np.max(A, axis=0)
    img = np.full((Ny2, Nx2), np.nan, dtype=np.float32)
    img[ys, xs] = max_pressure
    
    np.save(pathbase + "_PeakPressure", img)

def savesteadyfrequency(pathbase):
    #save steady state wave amplitude for continuous pulse simulation
    fs = 1/dt
    record_cycles = 3
    samples_per_cycle = int(fs / fc)
    record_length = record_cycles * samples_per_cycle
    last_segment = A[-record_length:]  # take last N samples
    
    fft_vals = np.fft.rfft(last_segment, axis = 0)
    fft_freqs = np.fft.rfftfreq(len(last_segment), 1/fs)
    magnitude_spectrum = np.abs(fft_vals)
    
    fundmental_freq = np.argmin(np.abs(fft_freqs))
    mag_f0 = magnitude_spectrum[fundmental_freq]
    steadywave = np.full((Ny2, Nx2), np.nan, dtype=np.float32)
    steadywave[ys, xs] = mag_f0
    np.save(pathbase + "_Steadywave", steadywave)

In [12]:
#all ROI slices simulation
batches = [["focused", "onepulse"], 
           ["planar", "onepulse" ],
           ["focused", "continuous"], 
           ["planar", "continuous"]]
SLICE_INDEX_LIST  = [i for i in range(86,121)]

i = 0
for b in batches:
    sourcetype = b[0]       # planar, focused
    simulationtype = b[1]   # onepulse, continuous
    for dname, _ in methods.items():
        for SLICE_INDEX in SLICE_INDEX_LIST:
            grid = dname[-5:]
            m = dname[:-5]
            sensors_pack, dom2d, medium2 = setmedium(m, grid, SLICE_INDEX)
            t_axis, srcs = setsource(medium2, dom2d, cfl)

            A = run(medium2, t_axis, srcs, sensors_pack)                           # (Nt, Ns) or (Nt, Ns, 1)
            A = np.asarray(A)
            A = A[..., 0]
            
            pathbase = "review2DSimulation/" + m + "_slice"+str(SLICE_INDEX)+ sourcetype + grid + simulationtype

            #savecrossection(pathbase)
            if simulationtype == "onepulse":
                savepeakpressure(pathbase)
            elif simulationtype == "continuous":
                savesteadyfrequency(pathbase)
                
            i = i + 1
            print("Saving simulation number", i, ". Setting", SLICE_INDEX, m, grid, sourcetype, simulationtype)
stophere #just to stop the notebook running

Saving simulation number 1 . Setting 86 Mesh 0p5mm focused onepulse
Saving simulation number 2 . Setting 87 Mesh 0p5mm focused onepulse
Saving simulation number 3 . Setting 88 Mesh 0p5mm focused onepulse
Saving simulation number 4 . Setting 89 Mesh 0p5mm focused onepulse
Saving simulation number 5 . Setting 90 Mesh 0p5mm focused onepulse
Saving simulation number 6 . Setting 91 Mesh 0p5mm focused onepulse
Saving simulation number 7 . Setting 92 Mesh 0p5mm focused onepulse
Saving simulation number 8 . Setting 93 Mesh 0p5mm focused onepulse
Saving simulation number 9 . Setting 94 Mesh 0p5mm focused onepulse
Saving simulation number 10 . Setting 95 Mesh 0p5mm focused onepulse
Saving simulation number 11 . Setting 96 Mesh 0p5mm focused onepulse
Saving simulation number 12 . Setting 97 Mesh 0p5mm focused onepulse
Saving simulation number 13 . Setting 98 Mesh 0p5mm focused onepulse
Saving simulation number 14 . Setting 99 Mesh 0p5mm focused onepulse
Saving simulation number 15 . Setting 100 M

NameError: name 'stophere' is not defined

In [None]:
#convergence test
batches = ["focused", "onepulse"]
SLICE_INDEX  = 90
cfls = [0.5, 0.4, 0.3, 0.2, 0.1, 0.05]
dname = "GroundTruth0p5mm"

sourcetype = batches[0]       # planar, focused
simulationtype = batches[1]   # onepulse, continuous

for cfl in cfls:
    grid = dname[-5:]
    m = dname[:-5]
    sensors_pack, dom2d, medium2 = setmedium(m, grid, SLICE_INDEX)
    t_axis, srcs = setsource(medium2, dom2d, cfl)

    A = run(medium2, t_axis, srcs, sensors_pack)                           # (Nt, Ns) or (Nt, Ns, 1)
    A = np.asarray(A)
    A = A[..., 0]
    
    pathbase = "2DSimulation_allslice/" + m + "_slice"+str(cfl)+ sourcetype + grid + simulationtype

    if simulationtype == "onepulse":
        savepeakpressure(pathbase)
    elif simulationtype == "continuous":
        savesteadyfrequency(pathbase)
        
    print("Saving simulation number", ". Setting", cfl, SLICE_INDEX, m, grid, sourcetype, simulationtype)

stophere

In [None]:
# interact plot of a simulation (if want)
from ipywidgets import interact, FloatSlider
def show_snapshot(t_us):
    t_idx = int(np.clip((t_us*1e-6)/dt, 0, Nt-1))
    frame_img = to_image(A[t_idx], fill=np.nan)
    vmax = np.nanmax(np.abs(frame_img)) or 1.0
    plt.figure(figsize=(6,5))
    im = plt.imshow(frame_img, origin="upper", cmap="RdBu_r", vmin=-vmax, vmax=vmax)
    plt.title(f"Pressure snapshot  t = {t_idx*dt*1e6:.1f} µs")
    plt.axis("off"); plt.colorbar(im, fraction=0.046, pad=0.04); plt.show()

interact(
    show_snapshot,
    t_us=FloatSlider(min=0, max=Nt*dt*1e6, step=dt*1e6, value=0.35*Nt*dt*1e6)
)
