In [2]:
from __future__ import annotations
import numpy as np
import jax.numpy as jnp
from jax import jit
from typing import Tuple, Optional, Dict
import nibabel as nib
import matplotlib.pyplot as plt

from jwave.geometry import Domain, Medium, TimeAxis, Sensors, Sources
from jwave.acoustics import simulate_wave_propagation

In [3]:
# Paths and physical / numerical constants

# Segmentation volume for each "<method><grid>" case; insertion order is the run order.
# NOTE(review): both GroundTruth entries point at the same file — confirm this is intended.
methods = dict(
    Mesh0p5mm="UpsampledSegments/C4MeshUpsampled0.5mm.nii.gz",
    GroundTruth0p2mm="UpsampledSegments/newC4HalfGroundTruth.nii.gz",
    GroundTruth0p5mm="UpsampledSegments/newC4HalfGroundTruth.nii.gz",
    NN0p5mm="UpsampledSegments/C4NNUpsampled0.5mm.nii.gz",
    NN0p2mm="UpsampledSegments/C4NNUpsampled.nii.gz",
    Linear0p2mm="UpsampledSegments/C4LinearUpsampled.nii.gz",
    Mesh0p2mm="UpsampledSegments/C4MeshUpsampled.nii.gz",
    Linear0p5mm="UpsampledSegments/C4LinearUpsampled0.5mm.nii.gz",
)

# Canal ROI: every nonzero voxel becomes a sensor, addressed by (z, y, x) grid index.
ROI_path = "ROIs/erode_C4ROI3D.npy"
ROI = np.load(ROI_path)
zs, ys, xs = np.where(ROI > 0)
Ns = zs.size  # number of sensors
sensors = Sensors(positions=tuple(tuple(int(v) for v in axis) for axis in (zs, ys, xs)))

# Acoustic properties: background vs. bone (m/s, kg/m^3).
c0 = 1500.0
rho0 = 1000.0
c_bone = 2200.0
rho_bone = 2600.0

# Isotropic grid spacing in metres (0.1 mm), applied to every volume.
# NOTE(review): the sensor indices above are shared by all volumes, so they are
# presumably all resampled onto this same grid — confirm.
voxel_size = 1e-4

fc = 1.0e6              # carrier frequency, Hz (NOTE: nothing visible here caps it to the grid)
cfl = 0.3               # CFL number passed to TimeAxis.from_medium
n_burst_onepulse = 3    # number of carrier cycles emitted in "onepulse" mode
src_z_offset = 20 + 3   # source-plane z index: a few cells inside the boundary
FOCUS_DEPTH_MM = 60     # focal depth in mm; converted to a z index counted from z = 0

In [4]:
def setmedium(key: Optional[str] = None):
    """Load one segmentation volume and build the jwave domain and two-phase medium.

    Parameters
    ----------
    key : str, optional
        Key into ``methods`` selecting the NIfTI file. When omitted, ``m + grid``
        is used, i.e. the module-level variables set by the batch loop (legacy behaviour).

    Returns
    -------
    dom3d : Domain
        Grid of shape (Nz, Ny, Nx) with isotropic spacing ``voxel_size``.
    medium3 : Medium
        Sound speed / density maps of shape (Nz, Ny, Nx, 1), float32.
    dims : tuple
        ``(dz, dy, dx, Nz, Ny, Nx)`` — spacings (floats) and grid sizes (ints).
    """
    if key is None:
        key = m + grid  # legacy: read the batch loop's globals
    nifti_path = methods[key]
    print(nifti_path)
    vol = nib.load(nifti_path).get_fdata()

    # Every volume uses the same isotropic spacing; the NIfTI header spacing is not read.
    dz = dy = dx = float(voxel_size)
    Nz, Ny, Nx = map(int, vol.shape)

    # Bone mask. np.int32 truncates toward zero, so fractional labels below 1 become
    # background. NOTE(review): labels other than 0/1 would scale c/rho linearly — confirm
    # the volumes are strictly binary.
    mask3d = np.int32(vol)

    # Map 0 -> background (c0, rho0) and 1 -> bone (c_bone, rho_bone).
    c_field = mask3d * (c_bone - c0) + c0
    rho_field = mask3d * (rho_bone - rho0) + rho0

    dom3d = Domain((Nz, Ny, Nx), (dz, dy, dx))

    # Append a trailing singleton axis: (Nz, Ny, Nx) -> (Nz, Ny, Nx, 1).
    c_field_reshaped = jnp.asarray(c_field, dtype=jnp.float32)[..., None]
    rho_field_reshaped = jnp.asarray(rho_field, dtype=jnp.float32)[..., None]
    medium3 = Medium(domain=dom3d, sound_speed=c_field_reshaped, density=rho_field_reshaped)

    return dom3d, medium3, (dz, dy, dx, Nz, Ny, Nx)

In [5]:
def _pulse_train(burst, n_burst, max_len):
    """Return ``n_burst`` back-to-back copies of ``burst`` (float32), truncated to ``max_len`` samples."""
    return np.tile(np.asarray(burst, dtype=np.float32), max(int(n_burst), 0))[:max_len]


def _delayed_signals(train, delays, Nt):
    """Return a (len(delays), Nt) float32 array whose row i is ``train`` starting at sample delays[i].

    ``delays`` are non-negative integer sample offsets. Samples past ``Nt`` are dropped and
    rows whose delay is >= ``Nt`` stay all-zero.
    """
    delays = np.asarray(delays, dtype=np.int64).ravel()
    # Many elements share the same integer delay: fill one row per distinct delay and
    # fan it out through the inverse index instead of filling every element separately.
    uniq, inverse = np.unique(delays, return_inverse=True)
    rows = np.zeros((uniq.size, Nt), np.float32)
    for row, d in zip(rows, uniq):
        d = int(d)
        if d >= Nt:
            continue
        n = min(train.size, Nt - d)
        row[d:d + n] = train[:n]
    return rows[inverse.ravel()]


def setsource(medium, dom, dims, source_type=None, simulation_type=None):
    """Build the time axis and the source plane (planar or focused) for one simulation.

    Every (y, x) grid column gets one point source at z = ``src_z_offset``. All elements
    emit the same one-cycle sine at ``fc``, repeated ``n_burst`` times back to back.

    Parameters
    ----------
    medium : Medium
        Passed to ``TimeAxis.from_medium`` (with ``cfl``) to choose dt.
    dom : Domain
        Domain given to ``Sources``.
    dims : tuple
        ``(dz, dy, dx, Nz, Ny, Nx)`` as returned by ``setmedium``.
    source_type : {"planar", "focused"}, optional
        Defaults to the module-level ``sourcetype`` (legacy behaviour).
    simulation_type : {"onepulse", "continuous"}, optional
        Defaults to the module-level ``simulationtype`` (legacy behaviour).

    Returns
    -------
    tuple
        ``(t_axis, srcs, sensors)``; ``sensors`` is the module-level ``Sensors`` object.

    Raises
    ------
    ValueError
        If the source or simulation type is not a supported value.
    """
    if source_type is None:
        source_type = sourcetype
    if simulation_type is None:
        simulation_type = simulationtype
    if source_type not in ("planar", "focused"):
        raise ValueError(f"unknown sourcetype {source_type!r}; expected 'planar' or 'focused'")

    dz, dy, dx, Nz, Ny, Nx = dims

    # One element per (x, y) column, all on the plane z = src_z_offset.
    src_x, src_y = np.meshgrid(np.arange(Nx), np.arange(Ny), indexing="ij")
    src_x, src_y = src_x.ravel(), src_y.ravel()
    src_z = np.full_like(src_y, src_z_offset)
    Ne = int(src_z.size)

    # Time window: one-way transit across the z extent at c0 (x1.5 in continuous mode).
    depth_time = (Nz * dz) / c0
    burst_time = 1 / fc
    if simulation_type == "onepulse":
        n_burst = n_burst_onepulse
        t_end = depth_time
    elif simulation_type == "continuous":
        n_burst = int(depth_time * 1.5 / burst_time) + 1  # enough cycles to fill t_end
        t_end = depth_time * 1.5
    else:
        raise ValueError(
            f"unknown simulationtype {simulation_type!r}; expected 'onepulse' or 'continuous'"
        )

    t_axis = TimeAxis.from_medium(medium, cfl=cfl, t_end=t_end)
    Nt = int(t_axis.Nt)
    dt = float(t_axis.dt)

    # One carrier cycle sampled at dt (computed with jnp, as before, to keep the same float32 values).
    tau = jnp.arange(int(jnp.ceil(burst_time / dt))) * dt
    s0 = np.asarray(jnp.sin(2 * jnp.pi * fc * tau), np.float32)
    train = _pulse_train(s0, n_burst, Nt)

    if source_type == "planar":
        # Plane wave: every element fires the same train with zero delay.
        signals_np = np.zeros((Ne, Nt), np.float32)
        signals_np[:, :train.size] = train
    else:
        # Focused wave: delay each element so all wavefronts reach the focus together.
        focus_z_pix = int(round((FOCUS_DEPTH_MM * 1e-3) / voxel_size))  # z index counted from 0
        focus_y_pix = Ny * 2 // 5
        focus_x_pix = 100  # NOTE(review): hard-coded x index — confirm it lies inside [0, Nx)

        zz, yy, xx = src_z * dz, src_y * dy, src_x * dx
        fz = float(focus_z_pix) * dz
        fy = float(focus_y_pix) * dy
        fx = float(focus_x_pix) * dx
        dist = np.sqrt((zz - fz) ** 2 + (yy - fy) ** 2 + (xx - fx) ** 2)

        tof = dist / c0
        tof = tof.max() - tof  # the element farthest from the focus fires first (delay 0)
        delays = np.round(tof / dt).astype(int)
        signals_np = _delayed_signals(train, delays, Nt)

        # Bursts on one element never overlap, so this only rescales the peak to 1.
        peak = float(np.abs(signals_np).max()) if signals_np.size else 0.0
        if peak > 0:
            signals_np /= peak

    srcs = Sources(
        positions=(src_z.tolist(), src_y.tolist(), src_x.tolist()),
        signals=jnp.asarray(signals_np),
        dt=dt,
        domain=dom,
    )
    return t_axis, srcs, sensors


In [6]:
def run(medium, t_axis, srcs, sensors):
    """Run one jwave simulation and return whatever ``simulate_wave_propagation`` yields.

    The medium and time axis are passed positionally; sources and sensors are passed by keyword.
    """
    result = simulate_wave_propagation(
        medium,
        t_axis,
        sources=srcs,
        sensors=sensors,
    )
    return result

In [7]:
# Batch-run the simulations and save one metric per (method, grid, source, simulation) case.
batches = [["focused", "onepulse"],
           ["planar", "onepulse"],
           ["focused", "continuous"],
           ["planar", "continuous"]]

# (sourcetype, simulationtype) forced for every run; set to None to sweep `batches`.
# While an override is set every batch would repeat the same runs and overwrite the
# same files, so only a single batch is scheduled in that case.
batch_override = ("focused", "continuous")

simtorun = 2  # max simulations this cell runs; len(batches) * len(methods) = 32 for a full sweep


def _magnitude_at_frequency(traces, fs, freq, record_cycles=3):
    """Return |rFFT| at the bin nearest ``freq`` over the last ``record_cycles`` periods of ``traces``.

    ``traces`` has time on axis 0 and is sampled at ``fs`` Hz. The value is the raw,
    unnormalized rFFT magnitude (divide by window_length / 2 to get a sine amplitude).
    """
    # Round (rather than truncate) so the window is as close as possible to whole cycles;
    # clamp to >= 1 because traces[-0:] would silently select the entire record.
    samples_per_cycle = max(1, int(round(fs / freq)))
    window = traces[-record_cycles * samples_per_cycle:]
    spectrum = np.abs(np.fft.rfft(window, axis=0))
    freqs = np.fft.rfftfreq(window.shape[0], d=1.0 / fs)
    return spectrum[int(np.argmin(np.abs(freqs - freq)))]


run_batches = [batch_override] if batch_override else batches
# Flattening to an explicit job list lets the `simtorun` cap apply across batches
# without a separate stop flag.
jobs = [(src, sim, dname) for src, sim in run_batches for dname in methods][:simtorun]

for i, (sourcetype, simulationtype, dname) in enumerate(jobs):
    # setmedium() / setsource() read these module-level names.
    grid = dname[-5:]   # grid tag, e.g. "0p5mm"
    m = dname[:-5]      # method name, e.g. "Mesh"
    print("Setting up simulation number", i, ". Setting", m, grid, sourcetype, simulationtype)

    dom, medium, dims = setmedium()
    t_axis, srcs, sensors = setsource(medium, dom, dims)

    print("running simulation:", m, grid, sourcetype, simulationtype)
    A = np.asarray(run(medium, t_axis, srcs, sensors))  # expected (Nt, Ns) or (Nt, Ns, 1)
    if A.ndim == 3 and A.shape[-1] == 1:
        A = A[..., 0]  # drop only a trailing singleton axis; a 2-D result is left as-is

    out_name = f"3Dsimulation_{m}_{grid}_{sourcetype}_{simulationtype}"
    if simulationtype == "onepulse":
        # Maximum over time of each sensor trace (positive peak).
        np.save(out_name, np.max(A, axis=0))
    elif simulationtype == "continuous":
        fs = 1.0 / float(t_axis.dt)  # sampling frequency, Hz
        # Magnitude at the drive frequency fc over the last 3 cycles (steady state).
        np.save(out_name, _magnitude_at_frequency(A, fs, fc, record_cycles=3))
    else:
        raise ValueError(
            f"unknown simulationtype {simulationtype!r}; expected 'onepulse' or 'continuous'"
        )

    print("Saving simulation number", i, ". Setting", m, grid, sourcetype, simulationtype)


Setting up simulation number 0 . Setting Mesh 0p5mm focused continuous
UpsampledSegments/C4MeshUpsampled0.5mm.nii.gz
running simulation: Mesh 0p5mm focused continuous
Saving simulation number 0 . Setting Mesh 0p5mm focused continuous
Setting up simulation number 1 . Setting GroundTruth 0p2mm focused continuous
UpsampledSegments/newC4HalfGroundTruth.nii.gz
running simulation: GroundTruth 0p2mm focused continuous
Saving simulation number 1 . Setting GroundTruth 0p2mm focused continuous
