In [1]:
import os
import time
import math
import numpy as np
import torch
from scipy.optimize import minimize
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.io as pio

# Directory that receives the generated HTML figures; created up front so
# later writes cannot fail on a missing folder.
BASE_OUTPUT_DIR = os.path.join("results", "appendix")
os.makedirs(BASE_OUTPUT_DIR, exist_ok=True)

# Use the GPU when CUDA is available, otherwise fall back to the CPU. Every
# tensor created below is placed on this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"[INFO] Using device: {device}")

# Fixed seed so that the random initial phases drawn with np.random in the
# driver section are reproducible from run to run.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
if device.type == 'cuda':
    torch.cuda.manual_seed_all(SEED)

# Speed of sound used to turn a frequency into a wavelength
# (presumably m/s in air -- the plot axes are labelled in metres).
speed_of_sound = 343.0

# 8 x 8 planar transducer array: centres at -3.5 ... 3.5 (step 1.0) scaled by
# 0.01 along x and y, all lying in the z = 0 plane.
x_coords = np.arange(-3.5, 4.5, 1.0)
y_coords = np.arange(-3.5, 4.5, 1.0)
transducer_positions = torch.tensor(
    [(x * 0.01, y * 0.01, 0.0) for x in x_coords for y in y_coords],
    dtype=torch.float64, device=device
)
num_transducers = len(transducer_positions)

# Amplitude lookup table with 91 entries. The field functions index it with
# the angle between the +z axis and the transducer-to-point vector, in whole
# degrees clamped to 0..90, and interpolate linearly between neighbours.
# NOTE(review): the physical meaning/units of the entries are not stated in
# this file -- confirm against the transducer characterisation data.
amp_table = torch.tensor([
    10.000, 11.832, 13.509, 15.414, 17.615, 18.240, 20.172, 21.626, 22.309, 23.054, 23.820, 23.820, 23.820,
    24.650, 24.650, 26.306, 29.052, 32.140, 35.496, 39.243, 41.952, 46.411, 47.958, 49.598, 53.009, 56.657,
    58.652, 60.581, 62.610, 64.885, 67.007, 69.282, 71.624, 71.624, 76.551, 76.551, 79.183, 84.676, 84.676,
    87.579, 90.388, 93.488, 93.488, 93.488, 96.695, 100.000, 96.695, 93.488, 93.488, 93.488, 90.388, 87.579,
    84.676, 84.676, 79.183, 76.551, 76.551, 71.624, 71.624, 69.282, 67.007, 64.885, 62.610, 60.581, 58.652,
    56.657, 53.009, 49.598, 47.958, 46.411, 41.952, 39.243, 35.496, 32.140, 29.052, 26.306, 24.650, 24.650,
    23.820, 23.820, 23.820, 23.054, 22.309, 21.626, 20.172, 18.240, 17.615, 15.414, 13.509, 11.832, 10.000
], dtype=torch.float64, device=device)

# Global optimisation grid: 201 samples per axis covering x, y in
# [-0.05, 0.05] and z in [0.0, 0.1]. The optimiser only ever evaluates a small
# cube of this grid around the target index.
grid_size = 201
x_vals = torch.linspace(-0.05, 0.05, grid_size, dtype=torch.float64, device=device)
y_vals = torch.linspace(-0.05, 0.05, grid_size, dtype=torch.float64, device=device)
z_vals = torch.linspace(0.0, 0.1, grid_size, dtype=torch.float64, device=device)
# Grid spacings as plain Python floats, passed to torch.gradient as `spacing`.
dx = (x_vals[1] - x_vals[0]).item()
dy = (y_vals[1] - y_vals[0]).item()
dz = (z_vals[1] - z_vals[0]).item()
# Half-width (in grid points) of the local cube: the objective is evaluated on
# (2 * half + 1) points per axis and read out at local index (half, half, half).
half = 2

# Default drive frequency and matching wavelength. Both are module-level
# globals that the driver section reassigns for every frequency it optimises;
# compute_gorkov_objective_local_torch reads `frequency` and
# objective_fn_torch reads `wavelength`.
frequency = 40000.0
wavelength = speed_of_sound / frequency

# Label of the optimisation recipe; only printed by the driver section.
METHOD_NAME = "regularized_hybrid_vortex"
# Weights of the three terms of the trapping metric:
#   w_lap * laplacian(U) - alpha * |grad U| - beta * pressure_penalty
W_LAPLACIAN = 1.0
ALPHA_F = 500.0
BETA_P = 5e-5
# Per-axis weights applied to the second derivatives inside
# compute_laplacian_weighted_torch.
W_LAPLACIAN_XX, W_LAPLACIAN_YY, W_LAPLACIAN_ZZ = 1, 1, 1
# Dictionary actually read by the objective at evaluation time.
CURRENT_WEIGHTS = dict(w_lap=W_LAPLACIAN, alpha=ALPHA_F, beta=BETA_P)
# 'smooth_abs' selects the smoothed |p| penalty in pressure_penalty; any other
# value falls back to a plain sqrt(|p|^2 + 1e-32).
PRESSURE_PENALTY_MODE = 'smooth_abs'
# Smoothing width of the 'smooth_abs' penalty, relative to the local RMS pressure.
EPS_REL = 1e-3


def compute_pressure_field_torch_v2(phase_vector, amp_table, wavelength, X, Y, Z, transducer_positions):
    """Evaluate the complex pressure summed over all transducers at each sample point.

    Every transducer contributes ``A(theta) / R * exp(i * (phase + k * R))``,
    where ``R`` is the distance to the sample point, ``theta`` the angle
    between the +z axis and the transducer-to-point vector, and ``A`` the
    linearly interpolated entry of ``amp_table``.

    Args:
        phase_vector: 1-D tensor holding one emission phase (radians) per
            transducer.
        amp_table: 1-D amplitude lookup table indexed by ``theta`` in whole
            degrees.
        wavelength: Acoustic wavelength, in the same length unit as the
            coordinates.
        X, Y, Z: Equally shaped coordinate tensors of the sample points.
        transducer_positions: ``(num_transducers, 3)`` tensor of emitter
            positions.

    Returns:
        Complex tensor with the same shape as ``X``.
    """
    wavenumber = 2.0 * math.pi / wavelength

    # (num_points, num_transducers, 3) offsets from every emitter to every point.
    points = torch.stack([X.reshape(-1), Y.reshape(-1), Z.reshape(-1)], dim=1)
    offsets = points.unsqueeze(1) - transducer_positions.unsqueeze(0)
    # The lower bound keeps the 1/R term finite when a point sits on an emitter.
    dist = torch.linalg.norm(offsets, dim=2).clamp_min(1e-9)

    # Angle from the +z axis in degrees; anything beyond 90 is treated as 90.
    cos_polar = torch.clamp(offsets[:, :, 2] / dist, -1.0, 1.0)
    angle_deg = torch.rad2deg(torch.acos(cos_polar)).clamp(0.0, 90.0)

    # Linear interpolation between the two neighbouring table entries; the
    # upper neighbour is clamped so the last entry interpolates with itself.
    last_entry = amp_table.shape[0] - 1
    idx_lo = torch.floor(angle_deg).long()
    idx_hi = torch.clamp(idx_lo + 1, max=last_entry)
    weight_hi = (angle_deg - idx_lo.to(torch.float64))
    amp_lo = amp_table[idx_lo]
    amp_hi = amp_table[idx_hi]
    directivity = amp_lo + weight_hi * (amp_hi - amp_lo)

    per_source = torch.polar(
        directivity / dist,
        phase_vector.unsqueeze(0) + wavenumber * dist,
    )
    return torch.sum(per_source, dim=1).reshape(X.shape)

def compute_gradient_torch(U, dx, dy, dz):
    """Return the finite-difference partial derivatives of a 3-D field.

    Args:
        U: Tensor with three dimensions (one spacing is supplied per axis).
        dx, dy, dz: Uniform sample spacing along axes 0, 1 and 2.

    Returns:
        The result of ``torch.gradient``: one tensor per axis, each shaped
        like ``U``, using first-order one-sided differences at the edges.
    """
    axis_spacing = (dx, dy, dz)
    partials = torch.gradient(U, spacing=axis_spacing, edge_order=1)
    return partials

def compute_laplacian_weighted_torch(U, dx, dy, dz):
    """Return the per-axis weighted Laplacian of a 3-D scalar field.

    The second derivatives are obtained by differentiating the first-order
    finite-difference gradient of ``U`` once more along the same axis, and are
    combined as ``W_LAPLACIAN_XX * U_xx + W_LAPLACIAN_YY * U_yy +
    W_LAPLACIAN_ZZ * U_zz``.

    Args:
        U: Tensor with three dimensions.
        dx, dy, dz: Uniform sample spacing along axes 0, 1 and 2.

    Returns:
        Tensor shaped like ``U``.
    """
    grad_U_x, grad_U_y, grad_U_z = compute_gradient_torch(U, dx, dy, dz)
    # Only the diagonal second derivatives are needed, so each gradient
    # component is differentiated along its own axis only instead of along all
    # three (which would compute six mixed derivatives just to discard them).
    L_xx = torch.gradient(grad_U_x, spacing=dx, dim=0, edge_order=1)[0]
    L_yy = torch.gradient(grad_U_y, spacing=dy, dim=1, edge_order=1)[0]
    L_zz = torch.gradient(grad_U_z, spacing=dz, dim=2, edge_order=1)[0]
    return (W_LAPLACIAN_XX * L_xx + W_LAPLACIAN_YY * L_yy + W_LAPLACIAN_ZZ * L_zz)

def pressure_penalty(p2_center, p_abs_center, p2_local_rms):
    """Return the pressure-magnitude penalty at the trap centre.

    Args:
        p2_center: Squared pressure magnitude at the centre point.
        p_abs_center: Pressure magnitude at the centre point. Accepted for
            interface compatibility; not used by either mode.
        p2_local_rms: Local reference level that scales the smoothing width
            in ``'smooth_abs'`` mode.

    Returns:
        ``sqrt(p2_center + eps**2)`` with ``eps = EPS_REL * (p2_local_rms +
        1e-32)`` when ``PRESSURE_PENALTY_MODE`` is ``'smooth_abs'``; otherwise
        ``sqrt(p2_center + 1e-32)``.
    """
    if PRESSURE_PENALTY_MODE != 'smooth_abs':
        # The tiny offset keeps the square root differentiable at zero pressure.
        return torch.sqrt(p2_center + 1e-32)

    smoothing = EPS_REL * (p2_local_rms + 1e-32)
    return torch.sqrt(p2_center + smoothing*smoothing)

def compute_gorkov_objective_local_torch(pf, dx, dy, dz):
    """Score a local complex pressure patch as a trap at its centre point.

    Builds the potential ``U = K1 * |p|^2 - K2 * |grad p|^2`` on the patch and
    combines, at local index ``(half, half, half)``, the weighted Laplacian of
    ``U``, the magnitude of ``grad U`` and a pressure penalty using the
    weights stored in ``CURRENT_WEIGHTS``.

    Args:
        pf: Complex 3-D pressure patch.
        dx, dy, dz: Sample spacing along the three patch axes.

    Returns:
        Scalar tensor ``w_lap * lap(U) - alpha * |grad U| - beta * penalty``.
    """
    def _mag_sq(z):
        # |z|^2 written with real/imag parts, exactly as used for every term.
        return z.real**2 + z.imag**2

    # Medium (rho0, c0) and particle (rho_p, c_p) constants.
    # NOTE(review): values and units are hard-coded here -- confirm them
    # against the experimental setup.
    rho0, c0, rho_p, c_p = 1.225, 343.0, 100.0, 2400.0
    # Reads the module-level `frequency`, which the driver reassigns per run.
    omega = 2 * math.pi * frequency
    r = 1.3e-3 / 2
    V = 4/3 * math.pi * r**3
    K1 = 0.25 * V * (1 / (c0**2 * rho0) - 1 / (c_p**2 * rho_p))
    K2 = 0.75 * V * ((rho0 - rho_p) / (omega**2 * rho0 * (rho0 + 2 * rho_p)))

    abs_p2 = (_mag_sq(pf))
    dpdx, dpdy, dpdz = compute_gradient_torch(pf, dx, dy, dz)
    v_sq = (_mag_sq(dpdx)) + (_mag_sq(dpdy)) + (_mag_sq(dpdz))

    U = K1 * abs_p2 - K2 * v_sq

    lap_field = compute_laplacian_weighted_torch(U, dx, dy, dz)
    dUdx, dUdy, dUdz = compute_gradient_torch(U, dx, dy, dz)

    # All point quantities are read at the centre of the patch.
    centre = (half, half, half)
    lap_at_centre = lap_field[centre]
    grad_norm_at_centre = torch.sqrt(dUdx[centre]**2 + dUdy[centre]**2 + dUdz[centre]**2)
    p2_center = abs_p2[centre]
    p_abs_center = torch.sqrt(p2_center + 1e-32)
    # Root of the patch-averaged |p|^2, used to scale the penalty smoothing.
    p2_local_rms = torch.sqrt(torch.mean(abs_p2))

    penalty = pressure_penalty(p2_center, p_abs_center, p2_local_rms)

    weights = CURRENT_WEIGHTS
    w_lap, alpha, beta = weights['w_lap'], weights['alpha'], weights['beta']

    return (w_lap * lap_at_centre
            - alpha * grad_norm_at_centre
            - beta  * penalty)

def objective_fn_torch(ph_tensor, x_idx, y_idx, z_idx):
    """Return the loss (negated trapping metric) for a phase vector.

    The pressure field is evaluated on the ``(2 * half + 1)``-point cube of the
    global grid centred on ``(x_idx, y_idx, z_idx)`` using the module-level
    ``wavelength``, and scored by ``compute_gorkov_objective_local_torch``.

    Args:
        ph_tensor: 1-D tensor with one phase per transducer.
        x_idx, y_idx, z_idx: Indices of the target point in ``x_vals``,
            ``y_vals`` and ``z_vals``.

    Returns:
        Scalar tensor; lower is better.

    Raises:
        ValueError: If an index is closer than ``half`` points to either end
            of its axis, so that no full stencil fits around it.
    """
    # Without this check a low index produces a negative slice start (an empty
    # patch) and a high index produces a truncated patch whose local
    # (half, half, half) entry is no longer the target point.
    for axis_name, idx, axis_vals in (("x", x_idx, x_vals), ("y", y_idx, y_vals), ("z", z_idx, z_vals)):
        last_valid = axis_vals.shape[0] - half - 1
        if not (half <= idx <= last_valid):
            raise ValueError(
                f"{axis_name}_idx={idx} is too close to the grid edge: "
                f"need {half} <= index <= {last_valid} for a full "
                f"{2 * half + 1}-point stencil"
            )

    x_local = x_vals[x_idx-half:x_idx+half+1]
    y_local = y_vals[y_idx-half:y_idx+half+1]
    z_local = z_vals[z_idx-half:z_idx+half+1]
    Xl, Yl, Zl = torch.meshgrid(x_local, y_local, z_local, indexing='ij')

    pf_local = compute_pressure_field_torch_v2(ph_tensor, amp_table, wavelength, Xl, Yl, Zl, transducer_positions)
    gorkov_metric = compute_gorkov_objective_local_torch(pf_local, dx, dy, dz)
    # scipy minimises, so the metric to be maximised is negated.
    return -gorkov_metric

# Grid indices of the point currently being optimised. The scipy callbacks
# read this global because scipy only passes them the phase vector.
current_indices = (0, 0, 0)


def set_indices(ix, iy, iz):
    """Select the grid point that the scipy objective and Jacobian evaluate."""
    global current_indices
    current_indices = ix, iy, iz

def objective_for_scipy(phases_np):
    """Evaluate the loss for scipy at the currently selected grid point.

    Args:
        phases_np: Array-like phase vector supplied by the optimiser.

    Returns:
        The loss as a Python float.
    """
    phases = torch.tensor(phases_np, dtype=torch.float64, device=device)
    ix, iy, iz = current_indices
    loss_value = objective_fn_torch(phases, ix, iy, iz)
    return float(loss_value.item())

def jacobian_for_scipy(phases_np):
    """Return the gradient of the loss with respect to the phases for scipy.

    The gradient is obtained by automatic differentiation of
    ``objective_fn_torch`` at the currently selected grid point.

    Args:
        phases_np: Array-like phase vector supplied by the optimiser.

    Returns:
        float64 NumPy array with one entry per phase.
    """
    phases = torch.tensor(phases_np, dtype=torch.float64, device=device, requires_grad=True)
    ix, iy, iz = current_indices
    loss_value = objective_fn_torch(phases, ix, iy, iz)
    # First-order gradient only; the graph is not needed after this call.
    d_loss = torch.autograd.grad(loss_value, phases, retain_graph=False, create_graph=False)[0]
    host_grad = d_loss.detach().cpu().numpy()
    return host_grad.astype(np.float64)

# --- Visualisation settings -------------------------------------------------
# Number of samples per axis of the volume rendered in each subplot.
GRID_N = (81, 81, 81)
# Number of isosurface levels drawn per subplot.
N_ISO = 8
# "quantile" places the levels at quantiles of |p|; any other value spaces
# them linearly between the minimum and maximum of |p|.
LEVEL_MODE = "quantile"
# Lowest and highest quantile used for the levels in "quantile" mode.
QUANTILE_RANGE = (0.75, 0.995)
# Opacity of a level is exp(EXP_ALPHA * (normalised_level - 1)), so higher
# levels are drawn more opaque ...
EXP_ALPHA = 3.5
# ... and the result is clipped to this (min, max) opacity range.
OPACITY_CLIP = (0.06, 0.9)

def _compute_pressure_field_abs(phase_vec_np, grid_bounds, grid_n, current_wavelength):
    """Sample the pressure magnitude |p| on a regular 3-D grid for plotting.

    The field model is the one implemented by
    ``compute_pressure_field_torch_v2``; calling it here (instead of keeping a
    second copy of the formula) guarantees that the plotted field is the field
    the optimiser worked on. The grid is processed in chunks to bound the size
    of the intermediate ``(points, transducers, 3)`` tensors.

    Args:
        phase_vec_np: Array-like with one phase per transducer.
        grid_bounds: ``((xmin, xmax), (ymin, ymax), (zmin, zmax))``.
        grid_n: ``(nx, ny, nz)`` number of samples per axis.
        current_wavelength: Wavelength to evaluate the field at.

    Returns:
        ``(x, y, z, p_abs)`` as NumPy arrays: the three 1-D axis coordinate
        arrays and the ``(nx, ny, nz)`` array of ``|p|``.
    """
    ph = torch.tensor(phase_vec_np, dtype=torch.float64, device=device)
    (xmin, xmax), (ymin, ymax), (zmin, zmax) = grid_bounds
    nx, ny, nz = grid_n

    X1 = torch.linspace(xmin, xmax, nx, dtype=torch.float64, device=device)
    Y1 = torch.linspace(ymin, ymax, ny, dtype=torch.float64, device=device)
    Z1 = torch.linspace(zmin, zmax, nz, dtype=torch.float64, device=device)
    X, Y, Z = torch.meshgrid(X1, Y1, Z1, indexing='ij')

    # Flatten once, outside the chunk loop.
    x_flat, y_flat, z_flat = X.reshape(-1), Y.reshape(-1), Z.reshape(-1)
    M = x_flat.numel()

    batch_points = 120_000
    p_abs_flat = torch.empty(M, dtype=torch.float64, device=device)

    # Plotting needs values only, so no autograd graph is recorded.
    with torch.no_grad():
        for start in range(0, M, batch_points):
            stop = min(start + batch_points, M)
            p_chunk = compute_pressure_field_torch_v2(
                ph, amp_table, current_wavelength,
                x_flat[start:stop], y_flat[start:stop], z_flat[start:stop],
                transducer_positions,
            )
            p_abs_flat[start:stop] = torch.abs(p_chunk)

    p_abs = p_abs_flat.reshape(nx, ny, nz).detach().cpu().numpy()
    return X1.detach().cpu().numpy(), Y1.detach().cpu().numpy(), Z1.detach().cpu().numpy(), p_abs

def _compute_levels_individual(pabs, mode, n_iso, qrange):
    """Choose isosurface levels and colour-scale limits for one volume.

    Args:
        pabs: NumPy array of pressure magnitudes.
        mode: ``"quantile"`` for quantile-based levels; any other value gives
            linearly spaced levels.
        n_iso: Number of levels to return.
        qrange: ``(low, high)`` quantiles spanned by the levels; only read in
            ``"quantile"`` mode.

    Returns:
        ``(levels, pmin, pmax)``: a float array of ``n_iso`` levels and the
        colour-scale limits as Python floats.
    """
    if mode == "quantile":
        lo_q, hi_q = qrange
        levels = np.quantile(pabs, np.linspace(lo_q, hi_q, n_iso))
        # Percentile-based colour limits are insensitive to a few extreme voxels.
        pmin = float(np.percentile(pabs, 1.0))
        pmax = float(np.percentile(pabs, 99.8))
    else:
        pmin, pmax = float(pabs.min()), float(pabs.max())
        # Drop both end points so that no level sits exactly on the extrema.
        levels = np.linspace(pmin, pmax, n_iso+2)[1:-1]

    # Degenerate colour range: fall back to the true extrema of the data.
    if pmax <= pmin:
        pmin = float(pabs.min())
        pmax = float(pabs.max())

    level_array = np.asarray(levels, dtype=float)
    return level_array, float(pmin), float(pmax)

def _add_isosurfaces_individual(fig, x1d, y1d, z1d, pabs, levels, pmin, pmax, subplot_ref, show_colorbar):
    """Add one isosurface trace per level to a subplot of ``fig``.

    Args:
        fig: Plotly figure created with ``make_subplots``.
        x1d, y1d, z1d: 1-D axis coordinate arrays of the volume.
        pabs: Array of pressure magnitudes on the ``(x, y, z)`` grid.
        levels: Iterable of iso-values to draw.
        pmin, pmax: Colour-scale limits; also used to normalise each level
            when computing its opacity.
        subplot_ref: ``(row, col)`` of the target subplot.
        show_colorbar: Whether this subplot shows a colour bar (attached to
            the last level's trace).
    """
    Xg, Yg, Zg = np.meshgrid(x1d, y1d, z1d, indexing='ij')

    # The flattened coordinates and values are identical for every level, so
    # they are built once here rather than copied again for each trace.
    x_flat, y_flat, z_flat = Xg.flatten(), Yg.flatten(), Zg.flatten()
    v_flat = pabs.flatten()

    # Guard against a zero-width colour range when normalising levels.
    denom = (pmax - pmin) if (pmax > pmin) else 1.0
    row, col = subplot_ref[0], subplot_ref[1]
    last_level = len(levels) - 1

    for i, lev in enumerate(levels):
        lev_norm = (lev - pmin) / denom
        # Exponential ramp: low levels are nearly transparent, high levels
        # nearly opaque, clipped to OPACITY_CLIP.
        opacity = float(np.clip(np.exp(EXP_ALPHA * (lev_norm - 1.0)), *OPACITY_CLIP))

        fig.add_trace(
            go.Isosurface(
                x=x_flat, y=y_flat, z=z_flat,
                value=v_flat,
                isomin=lev, isomax=lev,
                surface_count=1,
                caps=dict(x_show=False, y_show=False, z_show=False),
                opacity=opacity,
                colorscale="Viridis",
                cmin=pmin, cmax=pmax,
                showscale=show_colorbar and (i == last_level),
                colorbar=dict(title="|p| (Pa)", x=1.02 if col == 2 else -0.02) if show_colorbar else None,
                name=f"|p|={lev:.1f}",
                hovertemplate=f"Level: {lev:.1f}<br>x: %{{x:.3f}}<br>y: %{{y:.3f}}<br>z: %{{z:.3f}}<extra></extra>"
            ),
            row=row, col=col
        )

def create_grouped_plot(freq1, phases1, freq2, phases2, target_coords):
    """Render two optimised fields side by side and save them as one HTML file.

    Each subplot gets its own isosurface levels and colour scale. The value
    shown as "Max" in a subplot title is the upper colour-scale limit returned
    by ``_compute_levels_individual``. The file is written to
    ``BASE_OUTPUT_DIR`` as ``Figure_<freq1>_<freq2>.html`` and embeds a small
    script that keeps the two 3-D cameras in sync.

    Args:
        freq1, freq2: Frequencies (Hz) of the left and right subplot.
        phases1, phases2: Phase vectors for the two frequencies.
        target_coords: ``(x, y, z)`` of the target point; the plotted volume
            is positioned around it.
    """
    print(f"\n[PLOT] Generating comparison plot (Individual Style): {freq1}Hz vs {freq2}Hz...")

    # Plot window: +/- 0.02 around the target in x and y; asymmetric in z and
    # limited to the range 0.005 ... 0.085.
    cx, cy, cz = target_coords
    xlim = (cx - 0.02, cx + 0.02)
    ylim = (cy - 0.02, cy + 0.02)
    zlim = (max(0.005, cz - 0.025), min(0.085, cz + 0.055))
    grid_bounds = (xlim, ylim, zlim)

    wavelengths = (speed_of_sound / freq1, speed_of_sound / freq2)

    volumes = []
    for freq, phases, wl in zip((freq1, freq2), (phases1, phases2), wavelengths):
        print(f"  -> Computing {freq}Hz volume...")
        volumes.append(_compute_pressure_field_abs(phases, grid_bounds, GRID_N, wl))
    # Both volumes share the same grid, so the axes of the first are reused.
    (x1d, y1d, z1d, pabs_L), (_, _, _, pabs_R) = volumes

    levels_L, pmin_L, pmax_L = _compute_levels_individual(pabs_L, LEVEL_MODE, N_ISO, QUANTILE_RANGE)
    levels_R, pmin_R, pmax_R = _compute_levels_individual(pabs_R, LEVEL_MODE, N_ISO, QUANTILE_RANGE)

    titles = (f"{freq1} Hz (Max: {pmax_L:.1f} Pa)", f"{freq2} Hz (Max: {pmax_R:.1f} Pa)")
    fig = make_subplots(
        rows=1, cols=2,
        specs=[[{'type': 'scene'}, {'type': 'scene'}]],
        column_widths=[0.5, 0.5],
        horizontal_spacing=0.02,
        subplot_titles=titles
    )

    # Only the right-hand subplot carries a colour bar.
    panels = (
        (pabs_L, levels_L, pmin_L, pmax_L, (1, 1), False),
        (pabs_R, levels_R, pmin_R, pmax_R, (1, 2), True),
    )
    for pabs, levels, pmin, pmax, cell, with_bar in panels:
        _add_isosurfaces_individual(fig, x1d, y1d, z1d, pabs, levels, pmin, pmax, cell, show_colorbar=with_bar)

    common_scene = dict(
        xaxis=dict(title="x (m)", range=xlim),
        yaxis=dict(title="y (m)", range=ylim),
        zaxis=dict(title="z (m)", range=zlim),
        aspectmode="data",
        camera=dict(eye=dict(x=1.5, y=1.5, z=1.2))
    )

    fig.update_layout(
        title=dict(text=f"Comparison: {freq1} Hz vs {freq2} Hz (Indiv. Scale)", x=0.5),
        scene=common_scene,
        scene2=common_scene,
        margin=dict(l=0, r=0, t=60, b=0),
        width=1200, height=800
    )

    output_filename = f"Figure_{int(freq1)}_{int(freq2)}.html"
    full_path = os.path.join(BASE_OUTPUT_DIR, output_filename)

    # JavaScript appended to the page: mirrors a camera change in one scene to
    # the other; the `syncing` flag stops the mirrored update from echoing back.
    div_id = f"plotly_div_{output_filename.replace('.', '_')}"
    post_script = f"""
    (function(){{
      var gd = document.getElementById('{div_id}');
      if(!gd) return;
      var syncing = false;
      gd.on('plotly_relayout', function(e){{
        if(syncing) return;
        var up = {{}};
        if(e['scene.camera'])  {{ up['scene2.camera'] = e['scene.camera']; }}
        if(e['scene2.camera']) {{ up['scene.camera']  = e['scene2.camera']; }}
        if(Object.keys(up).length){{
          syncing = true;
          Plotly.relayout(gd, up).then(function(){{ syncing = false; }});
        }}
      }});
    }})();
    """
    html_str = pio.to_html(fig, include_plotlyjs="cdn", full_html=True, div_id=div_id, post_script=post_script)
    with open(full_path, "w", encoding="utf-8") as f:
        f.write(html_str)

    print(f"[OK] Saved Individual Scale Plot: {full_path}")

if __name__ == "__main__":
    # Driver: for each frequency pair, optimise the transducer phases for both
    # frequencies at the same target point and save a side-by-side figure.
    TARGET_COORD = (0.0, 0.0, 0.03)
    freq_pairs = [(40000.0, 60000.0), (80000.0, 100000.0)]

    print(f"[INFO] Target: {TARGET_COORD}")
    print(f"[INFO] Method: {METHOD_NAME}")
    print(f"[INFO] Processing Pairs: {freq_pairs}")

    # Not filled in by this section; kept in case other cells rely on the name.
    results = {}

    # The target is the same for every frequency, so its nearest grid indices
    # are looked up once instead of once per optimisation.
    target_coord_x, target_coord_y, target_coord_z = TARGET_COORD
    x_idx = torch.argmin(torch.abs(x_vals - target_coord_x)).item()
    y_idx = torch.argmin(torch.abs(y_vals - target_coord_y)).item()
    z_idx = torch.argmin(torch.abs(z_vals - target_coord_z)).item()
    set_indices(x_idx, y_idx, z_idx)

    for f1, f2 in freq_pairs:
        pair_results = []
        for freq in (f1, f2):
            print(f"\n--- Optimizing for {freq} Hz ---")
            # These two module-level globals are read by the objective
            # functions, so they must be reassigned here at module scope.
            frequency = float(freq)
            wavelength = speed_of_sound / frequency

            initial_phases = np.random.rand(num_transducers) * 2.0 * math.pi
            # perf_counter is monotonic, unlike the wall clock.
            start_time = time.perf_counter()
            res = minimize(
                fun=objective_for_scipy,
                x0=initial_phases,
                method='BFGS',
                jac=jacobian_for_scipy,
                options={'maxiter': 5000, 'disp': False, 'gtol': 1e-7}
            )
            end_time = time.perf_counter()
            print(f"  -> Done. Loss: {res.fun:.4e} (Steps: {res.nit}, Time: {end_time - start_time:.2f}s)")
            if not res.success:
                # The phases are still used for the plot, but the reader should
                # know that the optimiser stopped without meeting its tolerance.
                print(f"  [WARN] Optimizer did not report convergence: {res.message}")
            pair_results.append((freq, res.x))

        (freqA, phA), (freqB, phB) = pair_results
        create_grouped_plot(freqA, phA, freqB, phB, TARGET_COORD)

    print("\n[INFO] All processing completed.")

[INFO] Using device: cuda
[INFO] Target: (0.0, 0.0, 0.03)
[INFO] Method: regularized_hybrid_vortex
[INFO] Processing Pairs: [(40000.0, 60000.0), (80000.0, 100000.0)]

--- Optimizing for 40000.0 Hz ---
  -> Done. Loss: -7.8409e+00 (Steps: 73, Time: 1.70s)

--- Optimizing for 60000.0 Hz ---
  -> Done. Loss: -1.5356e+01 (Steps: 59, Time: 0.63s)

[PLOT] Generating comparison plot (Individual Style): 40000.0Hz vs 60000.0Hz...
  -> Computing 40000.0Hz volume...
  -> Computing 60000.0Hz volume...
[OK] Saved Individual Scale Plot: results/appendix/Figure_40000_60000.html

--- Optimizing for 80000.0 Hz ---
  -> Done. Loss: -2.3283e+01 (Steps: 49, Time: 0.74s)

--- Optimizing for 100000.0 Hz ---
  -> Done. Loss: -2.9616e+01 (Steps: 54, Time: 0.56s)

[PLOT] Generating comparison plot (Individual Style): 80000.0Hz vs 100000.0Hz...
  -> Computing 80000.0Hz volume...
  -> Computing 100000.0Hz volume...
[OK] Saved Individual Scale Plot: results/appendix/Figure_80000_100000.html

[INFO] All processing