In [None]:
import numpy as np
import jax.numpy as jnp
from jax import jit
from tqdm import tqdm
import matplotlib.pyplot as plt
from jaxopt import OptaxSolver
import optax
import time

from utils import beamform_pw, coherence_pw

In [None]:
#---- Load Data ----#

def load_simu(path, c=1.54, n_elem=128):
    """Load a simulated plane-wave acquisition stored as a pickled dict in a ``.npy`` file.

    Args:
        path: path to a ``.npy`` file whose single object is a dict with the keys
            'raw_data', 'pitch', 'excitation_frequency', 'sampling_frequency',
            'demod_freq' and 'angles'.
        c: value returned as the speed of sound. It is not stored in the file.
            Defaults to 1.54 (presumably mm/us, i.e. 1540 m/s -- confirm the units).
        n_elem: value returned as the number of probe elements. It is not read from
            the file. Defaults to 128.

    Returns:
        Tuple ``(iq_datas, angles, excitation_frequency, demod_freq,
        sampling_frequency, c, n_elem, pitch)``:
        ``iq_datas`` is the raw data as a complex64 JAX array; ``angles`` is a JAX
        column vector of shape (n_angles, 1); the remaining entries are returned
        exactly as stored in the file (or as passed for ``c`` and ``n_elem``).

    Raises:
        KeyError: if one of the expected keys is missing from the stored dict.

    Security:
        ``allow_pickle=True`` unpickles the file contents, which can execute
        arbitrary code. Only load files from trusted sources.
    """
    results = np.load(path, allow_pickle=True).item()

    # Cast on the host so no higher-precision intermediate copy is created on the device.
    iq_datas = jnp.asarray(np.asarray(results['raw_data'], dtype=np.complex64))
    angles = jnp.asarray(np.asarray(results['angles']).reshape((-1, 1)))

    return (iq_datas, angles, results['excitation_frequency'], results['demod_freq'],
            results['sampling_frequency'], c, n_elem, results['pitch'])


# Simulated acquisition to process, read from the local data/ folder
name = 'simulation_light_41_angles.npy'
path = f'data/{name}'
(iq_datas, angles, excitation_frequency, demod_freq,
 sampling_frequency, c, n_elem, pitch) = load_simu(path)

# Transmit and receive channel counts both equal the number of elements
N_tx = N_rx = n_elem

# Lateral element positions centered on the middle of the probe, as an (n_elem, 1) column
probe_x = jnp.arange(n_elem) * pitch
probe_x = (probe_x - probe_x.mean()).reshape((-1, 1))

In [None]:
# Parameters definition

# Aberration grid: the correction law is sampled on nd directions (rad) x nx lateral positions (mm)
d0, dd, nd = -0.5, 1/32, 32                 # first direction, step, count
eps_d = (d0 + dd * jnp.arange(nd)).reshape((-1, 1))        # (nd, 1) column

x0, dx, nx = -64*pitch, 128*pitch/64, 64    # first position, step (2 pitches), count
eps_x = (x0 + dx * jnp.arange(nx)).reshape((1, -1))        # (1, nx) row

# Optimization variables, both starting at zero (no correction).
# Stacked into a (2, nd, nx) array: h[0] is the h_t grid, h[1] the h_phi grid.
h_t = jnp.zeros((1, nd, nx))
h_phi = jnp.zeros((1, nd, nx))
h = jnp.concatenate([h_t, h_phi])

# Targets where the coherence is evaluated: 100 x 100 grid with a 0.3 mm step,
# x in [-15, 14.7] mm and z in [15, 44.7] mm, each flattened to a (1, 10000) row
x, z = jnp.meshgrid(jnp.arange(100) * 0.3 - 15,
                    jnp.arange(100) * 0.3 + 15, indexing='ij')
x, z = x.reshape((1, -1)), z.reshape((1, -1))

# Receive apodization, shape (n_elem, n_targets): an element is used for a target only if
# it is seen within +/-0.5 rad of the z axis. This corresponds to an F-number of about 0.9
# (1 / (2 tan 0.5)); an exact F-number of 1 would be a cut at atan(0.5) ~ 0.46 rad.
apod_rx = jnp.abs(jnp.arctan((x - probe_x) / z)) < 0.5
apod_tx = 1   # uniform (scalar) transmit weight

@jit
def tv(h_t):
    """Quadratic smoothness penalty of a 2-D grid.

    Returns the mean squared first difference along axis 0 plus the mean squared
    first difference along axis 1.

    Despite its name, this is not the L1 total variation: the differences are
    squared (an H1 / Tikhonov-type penalty), which favors smooth grids rather than
    piecewise-constant ones. The ``abs`` keeps it real-valued for complex input.
    """
    step_0 = jnp.diff(h_t, axis=0)
    step_1 = jnp.diff(h_t, axis=1)
    return (jnp.abs(step_0) ** 2).mean() + (jnp.abs(step_1) ** 2).mean()
    
@jit
def loss(h):
    """Objective minimized by the optimizer for the stacked correction law ``h``.

    ``h`` holds two (nd, nx) grids along axis 0: h[0] (the h_t grid) and h[1]
    (the h_phi grid). The result is the value of ``coherence_pw`` on the coarse
    target grid plus a quadratic smoothness penalty, weighted by 0.5, on h[0] only.
    h[1] is not regularized.

    NOTE(review): since this is minimized, ``coherence_pw`` presumably returns a
    quantity that should decrease (e.g. a negated coherence) -- confirm in utils.

    All other inputs (data, grids, apodizations, ...) are module-level globals. jit
    captures them as constants when it first traces, so re-run this cell after
    changing any of them.
    """
    law_t, law_phi = h[0], h[1]
    data_term = coherence_pw(iq_datas, law_t, law_phi, d0, dd, nd, x0, dx, nx,
                             probe_x, x, z, eps_d, eps_x, apod_tx, apod_rx,
                             sampling_frequency, demod_freq, angles)
    smoothness = tv(law_t)
    return data_term + 0.5 * smoothness

@jit
def display_img(h):
    """Beamform a display image with the correction law ``h`` applied.

    The reconstruction grid has 200 x 200 pixels with a 0.15 mm step:
    x in [-15, 14.85] mm and z in [15, 44.85] mm, built with indexing='ij'
    (x-major) and flattened. Receive apodization keeps elements seen within
    +/-0.5 rad of the z axis; transmit apodization is the scalar 1.

    Returns the output of ``beamform_pw`` on those 40000 pixels.
    NOTE(review): the display cells reshape it to (200, 200), so it is presumably
    one complex value per pixel -- confirm against utils.beamform_pw.

    Data and geometry come from module-level globals, which jit captures as
    constants on first trace.
    """
    offsets = jnp.arange(200) * 0.15
    x_hd, z_hd = jnp.meshgrid(offsets - 15, offsets + 15, indexing='ij')
    x_hd = x_hd.reshape((1, -1))
    z_hd = z_hd.reshape((1, -1))

    # Shape (n_elem, n_pixels): element/pixel pairs within the 0.5 rad acceptance angle
    apod_rx_hd = jnp.abs(jnp.arctan((x_hd - probe_x) / z_hd)) < 0.5
    apod_tx_hd = 1

    return beamform_pw(iq_datas, h[0], h[1], d0, dd, nd, x0, dx, nx,
                       probe_x, x_hd, z_hd, eps_d, eps_x, apod_tx_hd, apod_rx_hd,
                       sampling_frequency, demod_freq, angles)

In [None]:
# Display Aberrated Image (before any correction)
img = display_img(h)
# display_img returns a flattened, x-major 200 x 200 grid; transpose so rows are depth.
bmode_db = (20 * jnp.log10(jnp.abs(img))).reshape((200, 200)).T
fig, ax = plt.subplots()
# Pixel centers span x in [-15, 14.85] mm and z in [15, 44.85] mm (0.15 mm step).
# The extent uses the pixel edges so the axes read in mm, with depth increasing downward.
# clim is an absolute 60 dB window (the image is not normalized to its maximum).
ax.imshow(bmode_db, cmap='gray', clim=(5, 65),
          extent=(-15.075, 14.925, 44.925, 14.925))
ax.set(title='Aberrated image', xlabel='x (mm)', ylabel='z (mm)');

In [None]:
# Prepare and Compile
# Per-parameter Adam step sizes, broadcast over the (2, nd, nx) parameter stack:
# 2e-2 for h[0] (h_t grid) and 1e-2 for h[1] (h_phi grid).
# NOTE(review): jaxopt is no longer actively developed; consider calling optax directly.
learning_rates = jnp.array([2e-2, 1e-2]).reshape((-1, 1, 1))
opt = OptaxSolver(opt=optax.adam(learning_rates), fun=loss)
state = opt.init_state(h)

# Warm-up: the first update triggers JIT compilation. The result is discarded, so this
# cell only compiles and does not silently advance `h` by one optimization step.
warmup = opt.update(h, state)
warmup.params.block_until_ready()
del warmup

In [None]:
# Run
# Re-running this cell continues from the current `h` / `state`; it does not restart
# the optimization.
n_iter = 100  # number of optimization steps
for _ in tqdm(range(n_iter), ncols=100):
    h, state = opt.update(h, state)
    # JAX dispatches work asynchronously. Wait for each step so the progress bar
    # reports the real compute time instead of only the dispatch time.
    h.block_until_ready()

In [None]:
# Display Corrected Image (after optimization)
img = display_img(h)
# display_img returns a flattened, x-major 200 x 200 grid; transpose so rows are depth.
bmode_db = (20 * jnp.log10(jnp.abs(img))).reshape((200, 200)).T
fig, ax = plt.subplots()
# Pixel centers span x in [-15, 14.85] mm and z in [15, 44.85] mm (0.15 mm step).
# The extent uses the pixel edges so the axes read in mm, with depth increasing downward.
# Same absolute 60 dB window as the aberrated image, so the two are directly comparable.
ax.imshow(bmode_db, cmap='gray', clim=(5, 65),
          extent=(-15.075, 14.925, 44.925, 14.925))
ax.set(title='Corrected image', xlabel='x (mm)', ylabel='z (mm)');