In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import jax.numpy as jnp
import matplotlib.pyplot as plt

import jax
from jax import jit
from jax.lib import xla_bridge
# `jax.lib.xla_bridge.get_backend` is deprecated; `jax.default_backend()` is the
# public API returning the platform name of the default backend.
print(f"Jax is using: {jax.default_backend()}")

import sys

# Make the project-local packages one directory up (`utils`, `imaging`)
# importable. Guarded so that re-running this cell does not append duplicates.
if '../' not in sys.path:
    sys.path.append('../')

# Simulate data

In [None]:
# Linear ultrasound transducer (P4-1): geometry and drive parameters.
nelements = 64  # number of transducer elements
element_pitch = 2.95e-4  # distance between adjacent transducer elements [m]
transducer_frequency = 2e6  # frequency of the transducer [Hz]
transducer_magnitude = 1e6  # magnitude of the transducer [Pa]

# Aperture length, measured from the first to the last element [m].
transducer_extent = element_pitch * (nelements - 1)
print(f"Transducer extent: {transducer_extent:.3f} m")

In [None]:
# Simulation grid. The grid spacing equals the element pitch, so consecutive
# transducer elements sit on consecutive grid points.
N = np.array([128, 128], dtype=int)  # grid size [grid points]
dx = np.full(2, element_pitch)  # grid spacing [m]
pml = np.array([20, 20])  # size of the perfectly matched layer [grid points]

# The array lies `transducer_depth` grid points in from the far end of axis 1
# and is centred along axis 0.
transducer_depth = pml[1]  # depth of the transducer [grid points]
transducer_x_start = N[0] // 2 - nelements // 2  # first element index along axis 0 [grid points]
element_positions = np.array(
    [
        np.arange(transducer_x_start, transducer_x_start + nelements),  # axis-0 indices
        np.full(nelements, N[1] - transducer_depth),  # axis-1 index (same for all elements)
    ],
    dtype=int,
)
element_positions

In [None]:
from utils.jwave_utils import get_domain, get_point_medium, get_homogeneous_medium

# define jwave medium
c0 = 1500 # speed of sound [m/s]
medium_params = {
    'c0': c0,  # speed of sound [m/s]
    'rho0': 1000,  # density [kg/m^3]
    'background_mean': 1,  # mean of the background noise
    'pml_size': pml[0]  # size of the perfectly matched layer [grid points]
}

domain = get_domain(N, dx)

# Scatterer-free reference medium (background_std=0). Its simulated data is
# subtracted from the scatterer data further down to isolate the scattered field.
speed_homogenous, density_homogenous = get_homogeneous_medium(domain, **medium_params, background_std=0, background_seed=29)

# Three point scatterers around the domain centre, as [axis-0, axis-1] grid indices.
scatterer_positions = np.array([[domain.N[0]//2, domain.N[1]//2 + 15],
                                [domain.N[0]//2, domain.N[1]//2],
                                [domain.N[0]//2+15, domain.N[1]//2 - 15]], dtype=int)
speed, density = get_point_medium(domain, scatterer_positions, **medium_params, background_std = 0.000, scatterer_radius=1, scatterer_contrast=1.1, background_seed=28)
# speed[:,element_positions[1][0]:] = 1500
# density[:,element_positions[1][0]:] = 1000

# imshow draws array axis 0 along the vertical axis and axis 1 along the
# horizontal axis. extent = [left, right, bottom, top] therefore takes its width
# from axis 1 and its height from axis 0, which matters once the grid is not square.
ext = [0, N[1]*dx[1], N[0]*dx[0], 0]
plt.scatter(element_positions[1]*dx[1], element_positions[0]*dx[0],
            c='r', marker='o', s=5, label='transducer element')
plt.imshow(speed, cmap='gray', extent=ext)
plt.colorbar(label='Speed of sound [m/s]')
plt.xlabel('[m]')
plt.ylabel('[m]')
plt.title('Speed of sound with transducer elements')
plt.legend(prop={'size': 7})
plt.gca().invert_yaxis()
plt.show()

In [None]:
from jwave.geometry import Medium, TimeAxis
from utils.jwave_utils import get_plane_wave_excitation

# Steering angle of the transmitted plane wave: 0 degrees, in radians.
angle = np.pi * 0 / 180

# Time axis built from the scatterer medium with a CFL number of 0.3.
time_axis = TimeAxis.from_medium(
    Medium(domain, speed, density, pml_size=pml[0]),
    cfl=0.3,
)
sources, signal, carrier_signal = get_plane_wave_excitation(
    domain,
    time_axis,
    transducer_magnitude,
    transducer_frequency,
    element_pitch,
    element_positions,
    angle=angle,
)

# Drive signal of a single source (index 10), for inspection.
plt.plot(sources.signals[10])
plt.gca().set(xlabel='Time point', ylabel='Amplitude [Pa]')
plt.show()

In [None]:
from utils.jwave_utils import get_data

# Same plane-wave shot through two media. First the scatterer medium, keeping
# the pressure field for the snapshot below. Then the homogeneous reference,
# whose pressure field is discarded.
pressure, data = get_data(
    speed, density, domain, time_axis, sources, element_positions
)
_, data_homogenous = get_data(
    speed_homogenous, density_homogenous, domain, time_axis, sources, element_positions
)

In [None]:
from jwave.utils import show_field

# Snapshot of the simulated pressure field at a single time step.
t_idx = 200
show_field(pressure[t_idx])
plt.title("Pressure field at t={} seconds".format(time_axis.to_array()[t_idx]))
plt.show()

In [None]:
# Recorded channel data: one column per transducer element, one row per time step.
plt.imshow(data, aspect='auto', cmap='seismic')
plt.gca().set(xlabel='Transducer elements', ylabel='Time point')
plt.show()

In [None]:
# The two media differ only by the point scatterers, so subtracting the
# homogeneous recording leaves (ideally) only the scattered signal.
output_data = data - data_homogenous
plt.imshow(output_data, aspect='auto', cmap='seismic')
plt.gca().set(xlabel='Transducer elements', ylabel='Time point')
plt.show()

# Reconstruction

## Single angle

### Naive

In [None]:
from utils.beamforming_utils import get_receive_beamforming

# Inter-element transmit delay of the plane wave, converted from seconds to
# time steps by dividing by dt (zero for the 0-degree shot).
signal_delay = element_pitch * np.sin(angle) / c0 / time_axis.dt
res = get_receive_beamforming(
    domain, time_axis, element_positions, output_data,
    signal, carrier_signal, signal_delay,
)

In [None]:
# Naive beamformed image, transposed so that axis 0 of `res` runs horizontally,
# with the vertical axis flipped so row indices increase upward.
plt.imshow(res.transpose(), cmap='seismic')
plt.colorbar()
plt.gca().invert_yaxis()
plt.show()

In [None]:
# from kwave.utils.filters import gaussian_filter
# from kwave.reconstruction.beamform import envelope_detection

# def postprocess_result(orig_res):
#     result = np.copy(orig_res)
#     for i in range(result.shape[0]):
#         result[i, :] = gaussian_filter(result[i, :], 1/dx[0], transducer_frequency, 100.0)
#     for i in range(result.shape[0]):
#         result[i, :] = envelope_detection(result[i, :])
#     return np.flipud(result).T

# bmode=postprocess_result(res)
# plt.imshow(bmode, cmap='seismic', interpolation='nearest')
# plt.colorbar()
# plt.gca().invert_yaxis()
# plt.show()

### ntk

In [None]:
from imaging.demodulate import demodulate_rf_to_iq

# Sampling rate of the simulated traces [Hz].
freq_sampling = 1 / time_axis.dt
# Demodulate the scattered RF data to IQ around the transducer frequency.
# The helper also returns a carrier frequency, presumably the one it used.
iq_signals, freq_carrier = demodulate_rf_to_iq(
    output_data, freq_sampling, freq_carrier=transducer_frequency
)

In [None]:
# Beamforming pixel grid, sampled exactly on the simulation grid.
# NOTE(review): assumes beamform_delay_and_sum places element k at
# (k - (nelements - 1) / 2) * pitch (array centred on x = 0) and measures z from
# the array. It receives only the pitch, not element positions. Confirm.
Nz = domain.N[1] - transducer_depth  # grid points between the array row and index 0 of axis 1
dx0 = domain.dx[0]  # lateral grid spacing [m]
dz = domain.dx[1]  # axial grid spacing [m]
# np.linspace(-a, a, n) has spacing 2a / (n - 1), which stretches the pixel
# spacing by n / (n - 1) relative to the grid. np.arange keeps it equal to dx.
# Lateral positions are measured from the centre of the array.
array_centre = element_positions[0].mean()
x = (np.arange(domain.N[0]) - array_centre) * dx0
z = np.arange(Nz) * dz
X, Z = np.meshgrid(x, z)

In [None]:
from imaging.beamform import beamform_delay_and_sum

# Delay-and-sum of the 0-degree shot: every element transmits at the same
# time, so all transmit delays are zero.
beamformed_signal = beamform_delay_and_sum(
    iq_signals, X, Z, freq_sampling, freq_carrier,
    pitch=element_pitch,
    tx_delays=np.zeros(nelements),
)

In [None]:
# Copy first so the optional row mask below (currently disabled) never
# modifies `beamformed_signal` itself.
beamformed_signal_2 = beamformed_signal.copy()
# beamformed_signal_2[:25, :] = 0

# Display the magnitude of the beamformed image.
plt.imshow(np.abs(beamformed_signal_2), cmap='seismic')
plt.colorbar()
plt.show()

## Multiple angles

In [None]:
# Plane-wave compounding: simulate and beamform one steered shot per angle.
# NOTE: the loop rebinds the module-level `angle`, `sources`, `signal`,
# `carrier_signal`, `data`, `data_homogenous` and `output_data`. After it
# finishes they hold the last (+30 degree) shot, and the gradient cells further
# down use `data` and `sources` from that shot.
angles = np.linspace(-30*np.pi/180, 30*np.pi/180, 20)
results_naive = []
results_ntk = []
for angle in angles:
    print(f"Plane wave angle: {angle * 180 / np.pi:.2f} degrees")
    sources, signal, carrier_signal = get_plane_wave_excitation(domain, time_axis, transducer_magnitude, transducer_frequency, element_pitch, element_positions, angle=angle)
    _, data = get_data(speed, density, domain, time_axis, sources, element_positions)
    _, data_homogenous = get_data(speed_homogenous, density_homogenous, domain, time_axis, sources, element_positions)
    # Scattered signal only: remove the response of the scatterer-free medium.
    output_data = data - data_homogenous

    # naive beamforming
    # signal_delay = (element_pitch * np.sin(angle) / c0) / time_axis.dt
    # naive = get_receive_beamforming(domain, time_axis, element_positions, output_data, signal, carrier_signal, signal_delay)
    # results_naive.append(naive)

    # NTK beamforming
    # Transmit delay between adjacent elements for this steering angle [s].
    signal_delay = (element_pitch * np.sin(angle) / c0)
    # Per-element delays i * signal_delay, shifted so the largest is 0. All
    # delays are <= 0 for either sign of the angle and all zero at 0 degrees.
    # Kept in seconds, i.e. neither multiplied nor divided by dt.
    # NOTE(review): assumes beamform_delay_and_sum expects tx_delays in seconds.
    # If it expects samples, divide by time_axis.dt as the naive cell does. Confirm.
    raw_delays = np.arange(nelements) * signal_delay
    signal_delays = raw_delays - raw_delays.max()
    # shift = np.abs(int((transducer_extent * np.tan(angle) / domain.dx[0])))

    iq_signals, freq_carrier = demodulate_rf_to_iq(output_data, freq_sampling, freq_carrier=transducer_frequency)
    ntk = beamform_delay_and_sum(iq_signals, X, Z, freq_sampling, freq_carrier, pitch=element_pitch, tx_delays=signal_delays)
    ntk = np.flipud(ntk)
    # ntk = np.roll(ntk, shift, axis=0)
    results_ntk.append(ntk)

In [None]:
# Compound the per-angle images by summing them, then display the magnitude.
# `results_naive` stays empty while naive beamforming is commented out in the
# loop. np.sum([], axis=0) would silently return the scalar 0.0, so it is only
# compounded when it actually holds images.
compounded_naive = np.sum(results_naive, axis=0) if results_naive else None
compounded_ntk = np.sum(results_ntk, axis=0)

plt.imshow(np.abs(compounded_ntk), cmap='seismic', interpolation='nearest')
# plt.imshow(np.abs(results_ntk[0]), cmap='seismic', interpolation='nearest')
plt.colorbar()
plt.gca().invert_yaxis()
plt.title(f"Compounded NTK image ({len(results_ntk)} angles)")
plt.show()

## Gradient

In [None]:
from jax import value_and_grad
from utils.jwave_utils import get_data_only
from utils.solver_utils import linear_loss, nonlinear_loss

# Initial sound-speed estimate: the homogeneous reference medium.
params = speed_homogenous

# compute first linear gradient
# J = jax.jacrev(get_data_only, argnums=0)(jnp.array(speed), density_homogenous, domain, time_axis, sources, element_positions)
# linear_val_and_grad = value_and_grad(linear_loss, argnums=0)
# linear_loss_value, linear_gradient = linear_val_and_grad(params, J, output_data)

# compute first nonlinear gradient
# `data` and `sources` are whatever the most recently run cell left behind
# (the last shot of the multi-angle loop when run top to bottom).
nonlinear_val_and_grad = value_and_grad(nonlinear_loss, argnums=0)
# The loss value gets its own name. Binding it to `nonlinear_loss` would shadow
# the imported loss function, and any later call to it would fail.
nonlinear_loss_value, nonlinear_gradient = nonlinear_val_and_grad(params, data, density_homogenous, domain, time_axis, sources, element_positions)
# nonlinear_gradient = nonlinear_gradient.at[:, 90:].set(0) # apply mask
print(f"Nonlinear loss: {nonlinear_loss_value}")

In [None]:
# Visualize the first nonlinear gradient. The colour limits are symmetric about
# zero, so the white midpoint of the diverging 'seismic' map means "no update"
# and the two colours faithfully show the sign of the gradient.
grad_limit = float(jnp.max(jnp.abs(nonlinear_gradient)))
plt.figure(figsize=(8, 6))
plt.imshow(nonlinear_gradient.T, cmap='seismic', vmin=-grad_limit, vmax=grad_limit)
plt.title("First gradient")
plt.xlabel('x [gridpoints]')
plt.ylabel('y [gridpoints]')
plt.gca().invert_yaxis()
plt.colorbar(shrink=0.55)
plt.show()

In [None]:
from jax.example_libraries import optimizers
from tqdm.auto import tqdm
from jwave.signal_processing import smooth

losshistory = []
reconstructions = []
num_steps = 100

# Define optimizer
# init_fun, update_fun, get_params = optimizers.adam(1, 0.9, 0.9)
init_fun, update_fun, get_params = optimizers.sgd(1)
opt_state = init_fun(params)

# Define and compile the update function
@jit
def update(opt_state, k):
    """Take one optimizer step on the sound-speed map.

    Args:
        opt_state: optimizer state holding the current sound-speed estimate.
        k: iteration index, forwarded to the optimizer's update function.

    Returns:
        Tuple of (loss at the current estimate, updated optimizer state).
    """
    v = get_params(opt_state)
    lossval, gradient = nonlinear_val_and_grad(v, data, density_homogenous, domain, time_axis, sources, element_positions)
    # gradient = smooth(gradient)
    # gradient = gradient.at[:, 90:].set(0)
    # Scale so the largest |gradient| entry is 1 (max per-pixel step = step size).
    # An all-zero gradient would make this 0/0 = NaN, so leave it unscaled then.
    grad_scale = jnp.max(jnp.abs(gradient))
    gradient = jnp.where(grad_scale > 0, gradient / grad_scale, gradient)
    return lossval, update_fun(k, gradient, opt_state)

# Main loop
pbar = tqdm(range(num_steps))
for k in pbar:
    lossval, opt_state = update(opt_state, k)

    ## For logging
    new_params = get_params(opt_state)
    reconstructions.append(new_params)
    losshistory.append(lossval)
    pbar.set_description("Loss: {}".format(lossval))

In [None]:
# Visualize the final sound-speed estimate (last optimizer iterate).
plt.figure(figsize=(8, 6))
plt.imshow(reconstructions[-1].T, cmap='seismic')
plt.gca().set(xlabel='x [gridpoints]', ylabel='y [gridpoints]')
plt.gca().invert_yaxis()
plt.colorbar(shrink=0.55)
plt.show()