In [None]:
import jax.numpy as jnp
import jax
from jax import experimental
from jax.experimental import optimizers
from jax import grad, jit, vmap
from jax import random
from scipy.special import jv
import os
import datetime
import time
import sys

import numpy as np
import pandas as pd
import time
import math
from jax import device_put

# Manuscript: "Target Acoustic Field and Transducer State Optimization using Diff-PAT"
# Authors: Tatsuki Fushimi, Kenta Yamamoto, Yoichi Ochiai
# Corresponding email: tfushimi@slis.tsukuba.ac.jp
# JAX version 0.2.17

In [None]:
# Show the installed JAX version (the header above records 0.2.17 as the version used).
installed_jax_version = jax.__version__
print(installed_jax_version)

0.2.17


In [None]:
import os

# Output directory for the exported CSV files. exist_ok=True avoids the separate
# exists-check (check-then-create race) and keeps the cell idempotent on re-run.
os.makedirs('results', exist_ok=True)

In [None]:
# Transducer array geometry: an N x N grid with spacing `pitch`, centred on the origin in the z = 0 plane.
N = 16
pitch = 0.0105
size = [N, N]
# tra_x[row, col] = (col - centre) * pitch, identical along every row; tra_y is its transpose.
# centre = (N - 1) / 2 equals the former N//2 - 0.5 for even N and also centres the grid for odd N.
center = (N - 1) / 2
tra_x = np.tile((np.arange(N) - center) * pitch, (N, 1))
tra_y = tra_x.T
tra_z = np.full((N, N), 0.0)

P0 = 1.0                      # amplitude constant used by prop_matrix
c_sound = 346.0               # presumably the speed of sound in air (m/s)
freq = 40000.0                # presumably the transducer frequency (Hz)
l_ambda = c_sound / freq      # wavelength
k = 2.0 * math.pi / l_ambda   # wavenumber
r0 = 0.005                    # radius entering the 2*J1(k*r0*sin)/(k*r0*sin) directivity term
dropping_threshold = 1000     # number of optimisation steps run in the optimisation cell

settings_t = [1, 2, 3]              # not referenced by the visible cells
transducer_side = ['A', 'B', 'C']   # labels used in the exported file names
target_side = ['i', 'ii', 'iii']    # labels used in the exported file names
N_list = [2, 4]                     # not referenced by the visible cells

# Set to True to write the transducer coordinates to CSV (transposed, then flattened).
export_transducer_layout = False
if export_transducer_layout:
  for axis_name, coords in (('x', tra_x), ('y', tra_y), ('z', tra_z)):
    np.savetxt('transducer_' + axis_name + '.csv', coords.T.ravel(), delimiter=',')

In [None]:
# Pre-calculating propagation parts
def prop_matrix(array_x, array_y, array_z, x_p, y_p, z_p, P0, k, r0):
  """Pre-compute the complex transfer from every transducer element to every target point.

  Each element contributes
      D(k*r0*sin_alpha) * P0 / dist * exp(1j * k * dist),   D(x) = 2*J1(x)/x,
  where dist is the element-to-point distance and sin_alpha = lateral offset / dist.

  Args:
    array_x, array_y, array_z: element coordinates, arrays of identical shape (e.g. (N, N)).
    x_p, y_p, z_p: target-point coordinates, 1-D sequences of equal length.
    P0: amplitude constant.
    k: wavenumber.
    r0: radius used in the directivity term.

  Returns:
    Complex JAX array of shape (len(x_p),) + array_x.shape; entry [p, ...] is the transfer
    from each element to target point p. Targets coinciding with an element (dist == 0)
    yield NaN, as before.
  """
  arr_x = np.asarray(array_x, dtype=float)
  arr_y = np.asarray(array_y, dtype=float)
  arr_z = np.asarray(array_z, dtype=float)
  # Put the target index on a new leading axis so it broadcasts against the element grid.
  lead = (slice(None),) + (None,) * arr_x.ndim
  dx = arr_x - np.asarray(x_p, dtype=float)[lead]
  dy = arr_y - np.asarray(y_p, dtype=float)[lead]
  dz = arr_z - np.asarray(z_p, dtype=float)[lead]

  dist = np.sqrt(dx ** 2 + dy ** 2 + dz ** 2)
  sin_alpha = np.sqrt(dx ** 2 + dy ** 2) / dist

  # 2*J1(x)/x -> 1 as x -> 0. Evaluate that limit explicitly: substituting a "tiny" value
  # (the previous approach) underflows to 0 in float32 and produced 0/0 = NaN.
  arg = k * r0 * sin_alpha
  safe_arg = np.where(arg == 0, 1.0, arg)
  directivity = np.where(arg == 0, 1.0, 2 * jv(1, safe_arg) / safe_arg)

  amplitude = directivity * P0 / dist
  prop = amplitude * np.exp(1j * k * dist)

  return device_put(jnp.asarray(prop))

In [None]:
@jit
def Propagation_C(Tr):
  """Convert the optimiser parameters into a normalised amplitude and a unit phasor per transducer.

  Args:
    Tr: pair (real_part, imag_part) of equally shaped arrays forming the complex transducer state.

  Returns:
    (trans_amplitude, trans_phase_copmlex): |Tr| divided by its maximum, as a complex array with
    zero imaginary part, and exp(1j * angle(Tr)). Both have the shape of Tr[0].
  """
  tr_set = jax.lax.complex(Tr[0], Tr[1])
  temp_amp = jnp.abs(tr_set)
  temp_amp = temp_amp / jnp.max(temp_amp)

  # zeros_like (instead of the global `size`) keeps this valid for any parameter shape.
  trans_amplitude = jax.lax.complex(temp_amp, jnp.zeros_like(temp_amp))
  trans_phase = jnp.angle(tr_set)
  # Built from the angle rather than tr_set / |tr_set|, so a zero element maps to phase 1 instead of NaN.
  trans_phase_complex = jax.lax.complex(jnp.cos(trans_phase), jnp.sin(trans_phase))
  return trans_amplitude, trans_phase_complex

# Calculate Acoustic Pressure Field at the target point, output amplitude and phase.
@jit
def calculate_pressure(trans_amplitude, trans_phase_copmlex, transducer_prop, i):
  """Complex pressure at target point `i` as the coherent sum over all transducers.

  Args:
    trans_amplitude: complex per-transducer amplitude (zero imaginary part as produced by Propagation_C).
    trans_phase_copmlex: per-transducer unit phasor.
    transducer_prop: complex transfer array whose first axis indexes target points (see prop_matrix).
    i: target-point index.

  Returns:
    (point_amp, point_phase): magnitude and angle of sum(trans_phase_copmlex * trans_amplitude * transducer_prop[i]).
  """
  # Native complex multiplication replaces the hand-expanded real/imaginary products; the real and
  # imaginary parts of this sum are exactly the former point_Re / point_Im.
  field = jnp.sum(trans_phase_copmlex * (trans_amplitude * transducer_prop[i]))
  point_Re = jnp.real(field)
  point_Im = jnp.imag(field)
  point_amp = jnp.sqrt(point_Re ** 2 + point_Im ** 2)
  point_phase = jnp.arctan2(point_Im, point_Re)
  return point_amp, point_phase

@jit
def ErrFunc_iii(point_amp, point_phase, target_amp, target_pha, i):
  """Squared distance in the complex plane between the achieved and the target pressure at point `i`.

  Both pressures are expressed in Cartesian form (amp*cos(phase), amp*sin(phase)), and the
  squared real and imaginary differences are summed.
  """
  goal_amp = target_amp[i]
  goal_pha = target_pha[i]
  delta_re = point_amp * jnp.cos(point_phase) - goal_amp * jnp.cos(goal_pha)
  delta_im = point_amp * jnp.sin(point_phase) - goal_amp * jnp.sin(goal_pha)
  return delta_re ** 2 + delta_im ** 2

In [None]:
@jit
def loss_func_Ciii(Tr, transducer_prop, target_amp, target_pha):
  """Mean squared complex-pressure error over all target points (transducer setting C, target iii).

  Args:
    Tr: pair (real_part, imag_part) of the complex transducer state being optimised.
    transducer_prop: complex transfer array from prop_matrix; its first axis indexes target points.
    target_amp, target_pha: per-point target amplitude and phase, indexed like transducer_prop.

  Returns:
    Scalar loss: the mean over points of the error computed by ErrFunc_iii.
  """
  trans_amplitude, trans_phase_copmlex = Propagation_C(Tr)
  # Take the point count from the data rather than the global `point_num`: JAX clamps
  # out-of-range indices silently, so a stale global would give wrong results without raising.
  n_points = transducer_prop.shape[0]
  point_ids = jnp.arange(n_points)
  # Evaluating Loss Function: vmap over the point index instead of unrolling a Python loop of
  # n_points calls inside jit (same values, far smaller traced program and compile time).
  point_amp, point_phase = vmap(calculate_pressure, in_axes=(None, None, None, 0))(
      trans_amplitude, trans_phase_copmlex, transducer_prop, point_ids)
  err = vmap(ErrFunc_iii, in_axes=(0, 0, None, None, 0))(
      point_amp, point_phase, target_amp, target_pha, point_ids)
  return jnp.mean(err)

@jit
def step_Ciii(step, opt_state, transducer_prop, target_amp, target_pha):
  """Perform one optimiser update on loss_func_Ciii.

  Args:
    step: iteration counter forwarded to opt_update.
    opt_state: current optimiser state.
    transducer_prop, target_amp, target_pha: forwarded unchanged to loss_func_Ciii.

  Returns:
    (loss at the pre-update parameters, updated optimiser state).

  NOTE(review): opt_update and get_params are module-level globals, defined in the optimisation cell,
  that are captured when this function is first traced. jit does not retrace if they are rebound later.
  """
  current_params = get_params(opt_state)
  loss_and_grad = jax.value_and_grad(loss_func_Ciii)
  loss_value, param_grads = loss_and_grad(current_params, transducer_prop, target_amp, target_pha)
  next_state = opt_update(step, param_grads, opt_state)
  return loss_value, next_state

In [None]:
# Target field: point_num points on a circle of radius `radius` in the plane z = 0.1. Each has
# amplitude 50.0 and a phase equal to its angular position, so the phase winds once around the circle.
point_num = 60
radius = 40e-03
theta = np.linspace(0.0, 2 * math.pi, point_num)
# NOTE(review): linspace includes both 0 and 2*pi, so the first and last targets coincide
# (that point is effectively weighted twice). Kept to reproduce the reported results;
# np.linspace(..., endpoint=False) would give point_num distinct points.
x_p = radius * np.sin(theta)
y_p = radius * np.cos(theta)
z_p = np.full(point_num, 0.1)
target_amp = np.full(point_num, 50.0)
target_pha = theta

In [None]:
# Run the Diff-PAT optimisation for one (transducer, target) setting and export the results as CSV.
rng = np.random.default_rng(2021) # For reproducibility
loss_all = []
amplitude_all = []
phase_all = []

# Indices into transducer_side / target_side naming the setting run here ('C' / 'iii').
ii = 2
jj = 2
print_every = 100  # log the loss every print_every steps instead of printing every step

transducer_prop = prop_matrix(tra_x, tra_y, tra_z, x_p, y_p, z_p, P0, k, r0)

# Optimised parameters: real (Tr_A) and imaginary (Tr_B) parts of the complex transducer state,
# initialised uniformly in [0, 1).
Tr_A = rng.random((size[0], size[1]))
Tr_B = rng.random((size[0], size[1]))
Tr = [Tr_A, Tr_B]
# Initialize Optimizer. step_Ciii reads opt_update / get_params as globals when it is traced,
# so these names must stay as they are.
opt_init, opt_update, get_params = jax.experimental.optimizers.adam(0.1, b1=0.9, b2=0.999, eps=1e-08)
opt_state = opt_init(Tr)
loss_list = []
# Only the C / iii combination is run here. Selecting variants with Python conditionals inside the
# jitted step did not work; making them static arguments may improve the situation.
for st in range(dropping_threshold):
  value, opt_state = step_Ciii(st, opt_state, transducer_prop, target_amp, target_pha)
  loss_list.append(float(value))
  if st % print_every == 0 or st == dropping_threshold - 1:
    print('step', st, 'loss', loss_list[-1])

# Exporting Data: normalised amplitude and phase of the optimised transducer state.
Tr = get_params(opt_state)
part_A = Tr[0]
part_B = Tr[1]
tr_set = jax.lax.complex(part_A, part_B)
temp_amp = jnp.abs(tr_set)
temp_amp = temp_amp / jnp.max(temp_amp)
amplitude = temp_amp
phase = jnp.angle(tr_set)

amplitude_exports = np.array(amplitude)
phase_exports = np.array(phase)

loss_all.append(np.array(loss_list))
# Flatten the transposed grids, using the same element order as the transducer_*.csv export in the config cell.
amplitude_all.append(amplitude_exports.T.ravel())
phase_all.append(phase_exports.T.ravel())

phase_1d_arr = np.array(phase_all)
amp_1d_arr = np.array(amplitude_all)
loss_1d_arr = np.array(loss_all)

setting_tag = 'N_' + str(point_num) + '_Trans_' + transducer_side[ii] + '_Target_' + target_side[jj]
for quantity, data in (('Loss', loss_1d_arr), ('Phase', phase_1d_arr), ('Amplitude', amp_1d_arr)):
  np.savetxt(os.path.join('results', quantity + '_exports_' + setting_tag + '_settings.csv'), data, delimiter=',')
print('----- ' + setting_tag + ' COMPLETED -----')

5641.6997
2899.7002
1388.7063
582.88
207.03366
96.67074
98.33267
136.53654
172.88922
191.0702
189.49956
173.45811
149.79253
124.22917
100.50622
80.4953
64.7105
52.853104
44.243385
38.10355
33.714638
30.485767
27.970442
25.855324
23.937693
22.100393
20.288385
18.48885
16.715454
14.995741
13.362146
11.845407
10.470323
9.253069
8.200453
7.309557
6.569168
5.9612155
5.4631233
5.0502415
4.698271
4.385458
4.094235
3.8121562
3.532165
3.252129
2.9738214
2.7015839
2.4409702
2.1973739
1.9750205
1.7764008
1.6020037
1.4505534
1.3194176
1.2052609
1.1046576
1.0145926
0.93279934
0.85788924
0.7892402
0.7268195
0.6708869
0.6217198
0.57939035
0.54366267
0.5139359
0.48933622
0.46877608
0.4511303
0.43532884
0.4204513
0.4057929
0.39087555
0.3754572
0.35951102
0.34316057
0.32669502
0.31045878
0.29486325
0.28030327
0.2671125
0.25554365
0.24571805
0.23762968
0.23113616
0.22600313
0.22191593
0.21855222
0.21558927
0.21277446
0.20992082
0.20692375
0.20375213
0.2004386
0.1970409
0.19364478
0.19032766
0.18715277
0.

In [None]:
# Bundle the exported CSVs in results/ into colab_results.zip in the working directory, for download.
# Pure-Python replacement for `!zip -r /content/colab_results.zip /content/results/`. It works in any
# working directory (not only Colab's /content) and does not need the `zip` binary.
import shutil
archive_path = shutil.make_archive('colab_results', 'zip', root_dir='.', base_dir='results')
print('written', archive_path)

  adding: content/results/ (stored 0%)
  adding: content/results/Loss_exports_N_60_Trans_C_Target_iii_settings.csv (deflated 57%)
  adding: content/results/Phase_exports_N_60_Trans_C_Target_iii_settings.csv (deflated 55%)
  adding: content/results/Amplitude_exports_N_60_Trans_C_Target_iii_settings.csv (deflated 55%)
