In [None]:
# Reload edited modules automatically before each cell runs, so changes to the
# local helper packages imported below (jwave_utils, imaging.*) are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import jax.numpy as jnp
import matplotlib.pyplot as plt

from jax.lib import xla_bridge
import jax

# Report which platform JAX runs on (e.g. cpu / gpu). jax.default_backend() is the
# public API for this; jax.lib.xla_bridge.get_backend() is deprecated in recent JAX.
print(f"Jax is using: {jax.default_backend()}")

import sys
# Make the project's local packages (jwave_utils, imaging) importable. The guard keeps
# a re-run of this cell from appending duplicate entries to sys.path.
if '../' not in sys.path:
    sys.path.append('../')

# Load data

In [None]:
# `import scipy` alone exposes scipy.io only through SciPy's lazy submodule loading,
# which recent releases provide. Importing the submodule explicitly works everywhere.
import scipy.io

# Acquisition data (MATLAB .mat file). Alternative recordings:
# mat = scipy.io.loadmat('../data/1_48_flash_deab_lens_plastic_25.mat')
# mat = scipy.io.loadmat('../data/1_48_flash_deab_lens_25.mat')
mat = scipy.io.loadmat('../data/1_48_flash_deab_25.mat')

In [None]:
# define linear ultrasound transducer (P4-1)
nelements = mat['Trans']['numelements'][0,0][0,0] # number of transducer elements
element_pitch = mat['Trans']['spacingMm'][0,0][0,0]*1e-3 # distance between transducer elements
transducer_extent = (nelements - 1) * element_pitch # length of the transducer [m]
transducer_frequency = mat['Trans']['frequency'][0,0][0,0] * 1e6 # frequency of the transducer [Hz]
print(f"Transducer extent: {transducer_extent:.3f} m")

In [None]:
# Element indices of the HV-mux aperture: zero entries are dropped and the rest are
# shifted down by one (presumably MATLAB 1-based -> Python 0-based indices).
aperture_raw = np.squeeze(mat['Trans']['HVMux'][0, 0]['ApertureES'][0, 0])
apertureES = aperture_raw[np.nonzero(aperture_raw)] - 1
apertureES

In [None]:
# RF sampling interval: one transducer period divided into `samplesPerWave` samples.
samples_per_wave = mat['Receive']['samplesPerWave'][0, 0][0, 0]
dt_rf = 1 / (transducer_frequency * samples_per_wave)

# Received RF data restricted to the active aperture channels (axis 1).
raw_rf_data = mat['RcvData'][0, 0][:, apertureES, :]
raw_rf_data.shape

# Setup

In [None]:
from jwave_utils import get_domain

# Simulation grid: 256 x 256 points whose spacing equals the element pitch.
N = np.full(2, 256, dtype=int)    # grid size [grid points]
dx = np.repeat(element_pitch, 2)  # grid spacing [m]
pml = np.full(2, 20)              # size of the perfectly matched layer [grid points]

domain = get_domain(N, dx)

In [None]:
# define transducer position in domain
transducer_depth = pml[1] # depth of the transducer [grid points]
transducer_x_start = N[0]//2 - nelements//2 # start index of the transducer in the x-direction [grid points]
element_positions = np.array([
    np.linspace(transducer_x_start, transducer_x_start + nelements - 1, nelements),
    (N[1] - transducer_depth) * np.ones(nelements),
], dtype=int)

# filter for active transducers
nactive_elements = 48
start_elem_idx = 35
# element_positions = element_positions[:,:48]
# element_positions = element_positions[:,48:]
# element_positions = element_positions[:,start_elem_idx:start_elem_idx+nactive_elements]
# element_positions = element_positions[:,apertureES]

new_apertureES = (nelements - 1) - apertureES
element_positions = element_positions[:,new_apertureES]

element_positions   

In [None]:
# skull_thickness = round(6e-3 / dx[0]) # [gridpoints]
# skull_distance_from_transducer = round(10e-3 / dx[0]) # [gridpoints]
# regions = np.zeros(N)
# skull_start_y = element_positions[1][0] - skull_distance_from_transducer
# # regions[:, skull_start_y - skull_thickness:skull_start_y] = 1

# x = np.linspace(0, N[0] * dx[0], N[0])
# frequency = 1 / 0.015
# amplitude = 2e-3
# offset = 0e-3 
# phase_shift = 2 * np.pi * frequency * (N[0]//2 * dx[0])
# squiggle = amplitude * np.sin(2 * np.pi * frequency * x + phase_shift) + offset

# squiggle_grid_points = np.round(squiggle / dx[0]).astype(int)
# for i in range(N[0]):
#     y_start = skull_start_y - skull_thickness - squiggle_grid_points[i]
#     regions[i, y_start:skull_start_y] = 1

In [None]:
# Phantom geometry, with all lengths converted to grid points using dx[0].
# The slab ("skull"/lens) ends `skull_distance_from_transducer` points before the
# transducer row (element_positions[1][0]) and extends towards smaller y.
skull_distance_from_transducer = round(10e-3 / dx[0]) # [gridpoints]
skull_start_y = element_positions[1][0] - skull_distance_from_transducer
skull_thickness = round(4e-3 / dx[0]) # [gridpoints]

circle_radius = round(4.516e-3 / dx[0])
circle1_from_skull = round(1.484e-3 / dx[0])   # row of disks set to 1 (material)
circle2_from_skull = round(6.516e-3 / dx[0])   # row of disks set to 0 (background)
circle_separation_x = round(7.5e-3 / dx[0])
circle_separation_y = round(5.031e-3 / dx[0])  # not used below

regions = np.zeros(N, dtype=np.int_)
regions[:, skull_start_y - skull_thickness:skull_start_y] = 1

# Coordinate grids shaped like `regions` (axis 0 = x, axis 1 = y). This needs
# indexing='ij'. The default 'xy' indexing returns (N[1], N[0])-shaped arrays, which
# line up with `regions` only when the grid happens to be square.
x, y = np.meshgrid(np.arange(N[0]), np.arange(N[1]), indexing='ij')

# Disks centred at x = circle_x0 + k*circle_separation_x, for k = 0..5 towards +x
# (sign=+1) and then towards -x (sign=-1). Even k add material on the row
# circle1_from_skull points below skull_start_y. Odd k carve background on the row
# circle2_from_skull points below it. Assignment order is the same as two separate
# +x / -x loops.
circle_x0 = N[0]//2 - circle_radius
circle1_y = skull_start_y - circle1_from_skull
circle2_y = skull_start_y - circle2_from_skull
for sign in (1, -1):
    for i in range(3):
        circle1_cond = (x - (circle_x0 + sign*2*i*circle_separation_x))**2 + (y - circle1_y)**2 < circle_radius**2
        circle2_cond = (x - (circle_x0 + sign*(2*i+1)*circle_separation_x))**2 + (y - circle2_y)**2 < circle_radius**2
        regions[circle1_cond] = 1
        regions[circle2_cond] = 0

# Clear everything from the slab's transducer-side face towards the transducer.
regions[:, skull_start_y:] = 0

In [None]:
c0 = 1500  # speed of sound in water [m/s]
c_lens = 2160  # speed of sound in lens [m/s]
rho0 = 1000  # density of water [kg/m^3]

# Sound speed is c_lens wherever regions == 1 and c0 elsewhere. Density is uniform
# water (a lens density of 1800 kg/m^3 on regions == 1 is currently disabled).
speed_skull = np.where(regions == 1, c_lens, c0).astype(float)
density_skull = np.full(N, rho0, dtype=float)

# Transmit time-reversal

In [None]:
# define virtual transducer position beneath the skull
# virtual_positions = np.array([
#     np.linspace(0, N[0] - 1, N[0]),
#     (N[1] - transducer_depth) * np.ones(N[0]),
# ], dtype=int)
# Virtual array: the transducer element positions shifted 90 grid points towards
# smaller y (beneath the skull, further from the transducer row).
virtual_positions = element_positions - np.array([[0], [90]])

In [None]:
# Geometry overview: the sound-speed map with the real (red) and virtual (blue) elements.
# imshow puts grid axis 0 on the vertical axis and axis 1 on the horizontal axis. That
# is why extent = [0, N[1]*dx[1], N[0]*dx[0], 0] and the scatter calls plot
# (positions[1], positions[0]).
ext = [0, N[1]*dx[1], N[0]*dx[0], 0]
fig, ax = plt.subplots()
ax.scatter(element_positions[1]*dx[1], element_positions[0]*dx[0],
           c='r', marker='o', s=5, label='transducer element')
ax.scatter(virtual_positions[1]*dx[1], virtual_positions[0]*dx[0],
           c='b', marker='o', s=5, label='virtual element')
ax.imshow(speed_skull, cmap='gray', extent=ext)
# Centre line of the vertical axis. That axis spans grid axis 0 (0 .. N[0]*dx[0]), so
# its centre is N[0]*dx[0]/2. Using N[1]*dx[1]/2 was correct only for square grids.
ax.axhline(y=N[0]*dx[0]/2, color='r', linestyle='--')
ax.set_title('Sound speed and element positions')
ax.set_xlabel('[m]')
ax.set_ylabel('[m]')
ax.legend(prop={'size': 7})
ax.invert_yaxis()
plt.show()

In [None]:
from jwave import FourierSeries
from jwave.geometry import Medium, TimeAxis


def _as_field(values):
    """Wrap a 2-D grid map as a single-channel FourierSeries on `domain`."""
    return FourierSeries(jnp.expand_dims(values, -1), domain)


medium = Medium(
    domain=domain,
    sound_speed=_as_field(speed_skull),
    density=_as_field(density_skull),
    pml_size=pml[0],
)

# Keep the CFL-limited step for this medium, but simulate for as long as the recorded
# RF lasts (number of RF samples times dt_rf).
dt_sim = TimeAxis.from_medium(medium, cfl=0.3).dt
t_end_rf = raw_rf_data.shape[0] * dt_rf
time_axis = TimeAxis(dt_sim, t_end_rf)
t = time_axis.to_array()

In [None]:
from scipy.signal import resample

# Sampling rate of the stored transmit waveform TW.Wvfm1Wy [Hz].
# NOTE(review): 250 MHz is the value the original code divided by; confirm it against
# the acquisition setup.
TW_SAMPLE_RATE = 250e6

# Transmit waveform, resampled onto the simulation time step. Its duration
# (len / TW_SAMPLE_RATE seconds) is preserved.
source_signal = np.squeeze(mat['TW']['Wvfm1Wy'][0,0])
new_num_samples = int((1/time_axis.dt) * len(source_signal) / TW_SAMPLE_RATE)
source_signal = resample(source_signal, new_num_samples)

fig, (ax_time, ax_freq) = plt.subplots(1, 2)

# Time-domain signal
ax_time.plot(t[:len(source_signal)], source_signal)
ax_time.set_title('Time Domain Signal')
ax_time.set_xlabel('Time (s)')
ax_time.set_ylabel('Amplitude')

# One-sided magnitude spectrum. rfft/rfftfreq return only non-negative frequencies
# in increasing order. fft/fftfreq list the positive half and then the negative half,
# so a line plot drew a spurious segment from +fs/2 back to -fs/2 straight across
# the displayed 0-5 MHz window.
frequency = np.fft.rfftfreq(source_signal.size, d=time_axis.dt)
magnitude = np.abs(np.fft.rfft(source_signal))
ax_freq.plot(frequency, magnitude)
ax_freq.set_xlim(0, 5e6)
ax_freq.set_title('Frequency Spectrum')
ax_freq.set_xlabel('Frequency (Hz)')
ax_freq.set_ylabel('Magnitude')

fig.tight_layout()
plt.show()

In [None]:
# Make the transmit waveform exactly time_axis.Nt samples long (one per simulation
# step): zero-pad at the end, or truncate if the resampled waveform is already longer.
# A negative pad width would make np.pad raise.
n_steps = int(time_axis.Nt)
source_signal = np.pad(source_signal[:n_steps], (0, max(n_steps - source_signal.size, 0)), 'constant')
source_signal.shape

In [None]:
# from jwave_utils import get_plane_wave_excitation
# angle = 0
# virtual_sources, virtual_signal, virtual_carrier_signal = get_plane_wave_excitation(domain, time_axis, 1e6, transducer_frequency, dx[0], virtual_positions, angle=angle)

from jwave.geometry import Sources

# Every selected element transmits the same padded waveform: one signal row per element.
n_tx = element_positions.shape[1]
sources = Sources(
    positions=tuple(map(tuple, element_positions)),
    signals=jnp.tile(source_signal, (n_tx, 1)),
    dt=time_axis.dt,
    domain=domain,
)

fig, ax = plt.subplots()
ax.plot(sources.signals[0])
ax.set_xlabel('Time point')
ax.set_ylabel('Amplitude [Pa]')
plt.show()

In [None]:
# from jwave_utils import get_data

# # simulate data using jwave
# virtual_pressure_skull, virtual_data_skull = get_data(speed, density, domain, time_axis_sim, virtual_sources, element_positions)

In [None]:
# from jwave.utils import show_field

# t_idx = 400
# show_field(pressure_skull[t_idx])
# plt.title(f"Pressure field at t={time_axis.to_array()[t_idx]} seconds")
# plt.show()

In [None]:
# reverse_threshold = 750
# virtual_data = jnp.squeeze(virtual_pressure_skull.params[:, element_positions[0], element_positions[1]])
# virtual_data = jnp.flip(virtual_data, axis=0)
# virtual_data = virtual_data.at[:int(time_axis_sim.Nt - reverse_threshold), :].set(0)
# virtual_data = jnp.roll(virtual_data, - int(time_axis_sim.Nt - reverse_threshold), axis=0)
# plt.imshow(virtual_data, aspect='auto', cmap='seismic')   
# plt.show()

In [None]:
# from jwave.geometry import Sources

# time_reversed_sources = Sources(
#     positions=tuple(map(tuple, element_positions)),
#     signals=jnp.array(virtual_data.T),
#     dt=time_axis_sim.dt,
#     domain=domain,
# )

# plt.plot(time_axis_sim.to_array(), time_reversed_sources.signals[0])
# plt.xlabel('Time point')
# plt.ylabel('Amplitude [Pa]')
# plt.show()

In [None]:
# pressure_skull, data_skull = get_data(speed, density, domain, time_axis_sim, time_reversed_sources, virtual_positions)

In [None]:
# from jwave.utils import show_field

# t_idx = 800
# show_field(pressure_skull[t_idx])
# plt.title(f"Pressure field at t={time_axis.to_array()[t_idx]} seconds")
# plt.show()

# Receive time-reversal

In [None]:
from scipy.signal import resample

# Resample the measured RF data along axis 0 (time) from the acquisition rate
# 1/dt_rf onto the simulation step time_axis.dt.
new_num_samples = int((1/time_axis.dt) * raw_rf_data.shape[0] / (1/dt_rf))
raw_rf_data_sim = resample(raw_rf_data, new_num_samples, axis=0)

# int() above floors the sample count, but the simulation runs int(time_axis.Nt)
# steps, which is also the length the transmit waveform is padded to. Nt can exceed
# the floored count by one when the RF duration is not a multiple of dt.
# NOTE(review): assumes jwave's TimeAxis rounds Nt up.
# Zero-pad or truncate along time so the time-reversed RF used as a source has
# exactly one sample per simulation step.
n_steps = int(time_axis.Nt)
n_rf = raw_rf_data_sim.shape[0]
if n_rf < n_steps:
    pad_width = [(0, n_steps - n_rf)] + [(0, 0)] * (raw_rf_data_sim.ndim - 1)
    raw_rf_data_sim = np.pad(raw_rf_data_sim, pad_width)
else:
    raw_rf_data_sim = raw_rf_data_sim[:n_steps]
raw_rf_data_sim.shape

In [None]:
# from jwave_utils import get_plane_wave_excitation
# angle = 0
# sources, signal, carrier_signal = get_plane_wave_excitation(domain, time_axis, transducer_magnitude, transducer_frequency, dx[0], element_positions, angle=angle)
# _, data = get_data(speed, density, domain, time_axis, sources, element_positions)
# _, data_homogeneous = get_data(speed_homogeneous, density_homogeneous, domain, time_axis, sources, element_positions)
# output_data = data - data_homogeneous

from jwave_utils import get_data


def _element_sources(signals_time_major):
    """Emit a (time, element) signal matrix from the transducer elements as jwave Sources."""
    return Sources(
        positions=tuple(map(tuple, element_positions)),
        signals=signals_time_major.T,
        dt=time_axis.dt,
        domain=domain,
    )


# DE-ABERRATION
# NOTE(review): assumes get_data returns (pressure_field, sensor_data) with sensor_data
# laid out as (time, sensor). The flips on axis 0 and the transposes rely on that.

# 1) Time-reverse the first RF frame, emit it from the elements, record at the
#    virtual positions, then flip the recording back to forward time.
sources = _element_sources(jnp.flip(raw_rf_data_sim[:, :, 0], axis=0))
_, data = get_data(speed_skull, density_skull, domain, time_axis, sources, virtual_positions)
data = jnp.flip(data, axis=0)

# 2) Same excitation, recorded back at the transducer elements themselves.
pressure_skull, data_skull = get_data(speed_skull, density_skull, domain, time_axis, sources, element_positions)

# 3) Time-reverse that element recording, re-emit it, record at the virtual
#    positions and flip back.
sources = _element_sources(jnp.flip(data_skull, axis=0))
_, data_skull = get_data(speed_skull, density_skull, domain, time_axis, sources, virtual_positions)
data_skull = jnp.flip(data_skull, axis=0)

# De-aberrated signal at the virtual array: step-1 result minus step-3 result.
output_data = data - data_skull

In [None]:
# Back-propagated RF at the virtual positions (rows: time samples, columns: elements).
fig, ax = plt.subplots()
ax.imshow(data, aspect='auto', cmap='seismic')
ax.set_xlabel('Transducer elements')
ax.set_ylabel('Time point')
plt.show()

In [None]:
# De-aberrated data (data - data_skull), shown with the same layout as the plot of `data`.
fig, ax = plt.subplots()
ax.imshow(output_data, aspect='auto', cmap='seismic')
ax.set_xlabel('Transducer elements')
ax.set_ylabel('Time point')
plt.show()

In [None]:
from jwave.utils import show_field

# Snapshot of the simulated pressure field at a fixed time step.
t_idx = 400
snapshot_time = time_axis.to_array()[t_idx]
show_field(pressure_skull[t_idx])
plt.title(f"Pressure field at t={snapshot_time} seconds")
plt.show()

# Reconstruction

In [None]:
from imaging.demodulate import demodulate_rf_to_iq

# IQ demodulation at the transducer frequency. The raw RF (first frame) is sampled at
# 1/dt_rf; the de-aberrated data is sampled at the simulation rate 1/time_axis.dt.
rf_sampling_rate = 1/dt_rf
sim_sampling_rate = 1/time_axis.dt
iq_signals, freq_carrier = demodulate_rf_to_iq(
    raw_rf_data[:, :, 0], rf_sampling_rate, freq_carrier=transducer_frequency)
iq_signals_deab, freq_carrier_deab = demodulate_rf_to_iq(
    output_data, sim_sampling_rate, freq_carrier=transducer_frequency)

In [None]:
# N = domain.N[0]
Nz = domain.N[1] - transducer_depth
dx0 = domain.dx[0]

# Generate 1D arrays for x and z
x = np.linspace(-(domain.N[0]//2)*dx0, (domain.N[0]//2)*dx0, domain.N[0])
z = np.linspace(0, domain.N[1]*dx0, domain.N[1])

# Create 2D meshgrid for x and z
X, Z = np.meshgrid(x, z)

In [None]:
from imaging.beamform import beamform_delay_and_sum

# Delay-and-sum on the (X, Z) grid for the raw RF (rate 1/dt_rf) and the de-aberrated
# data (rate 1/time_axis.dt). Both use all-zero transmit delays (presumably a flat
# "flash" transmit; confirm against the acquisition).
beamformed_signal = beamform_delay_and_sum(
    iq_signals, X, Z, 1/dt_rf, freq_carrier,
    pitch=element_pitch, tx_delays=np.zeros(iq_signals.shape[1]))
# Size the delay vector from the de-aberrated data's own channel count, not the raw
# data's, so it stays correct if the two element sets ever differ.
beamformed_signal_deab = beamform_delay_and_sum(
    iq_signals_deab, X, Z, 1/time_axis.dt, freq_carrier_deab,
    pitch=element_pitch, tx_delays=np.zeros(iq_signals_deab.shape[1]))

In [None]:
# |beamformed_signal| on the beamforming grid: x across, z down.
# NOTE(review): assumes beamform_delay_and_sum returns an image shaped like X/Z,
# i.e. (len(z), len(x)). The simulation-grid `ext` runs its horizontal axis from 0
# to N[1]*dx[1], while this image's x axis is centred on 0, so the image is labelled
# with its own x/z extent instead.
fig, ax = plt.subplots()
im = ax.imshow(np.abs(beamformed_signal), cmap='seismic', extent=[x[0], x[-1], z[-1], z[0]])
ax.set_title('Beamformed (raw RF)')
ax.set_xlabel('x [m]')
ax.set_ylabel('z [m]')
# fig.colorbar(im, ax=ax)
plt.show()

In [None]:
# |beamformed_signal_deab| on the beamforming grid: x across, z down.
# NOTE(review): assumes beamform_delay_and_sum returns an image shaped like X/Z,
# i.e. (len(z), len(x)). The simulation-grid `ext` runs its horizontal axis from 0
# to N[1]*dx[1], while this image's x axis is centred on 0, so the image is labelled
# with its own x/z extent instead.
fig, ax = plt.subplots()
im = ax.imshow(np.abs(beamformed_signal_deab), cmap='seismic', extent=[x[0], x[-1], z[-1], z[0]])
ax.set_title('Beamformed (de-aberrated)')
ax.set_xlabel('x [m]')
ax.set_ylabel('z [m]')
# fig.colorbar(im, ax=ax)
plt.show()