In [1]:
import jax

jax.config.update("jax_enable_x64", val=True)

In [2]:
import jax.numpy as jnp
from jax import jit, vmap

from qpm import cwes2, mgoslt

NORO_FACTOR = 100 * 1.07 / 2.84
design_wl = 1.031
design_temp = 70.0
kappa_mag = 1.31e-5 / (2 / jnp.pi)
wls = jnp.linspace(1.025, 1.035, 1000)
num_domains_shg = int(9400 / 7.2 * 2)
num_domains_sfg = int(5600 / 1.96 * 2)
dk1_base = mgoslt.calc_twm_delta_k(design_wl, design_wl, design_temp)
dk2_base = mgoslt.calc_twm_delta_k(design_wl, design_wl / 2, design_temp)
dk1s = mgoslt.calc_twm_delta_k(wls, wls, 70.0)
dk2s = mgoslt.calc_twm_delta_k(wls, wls / 2, 70.0)
a0_b0 = jnp.array([1.0, 0.0, 0.0], dtype=jnp.complex128)

# --- Define QPM Grating Structure (kappas and widths) ---
# Widths are set for perfect phase matching at base dk values
shg_width = jnp.pi / dk1_base
sfg_width = jnp.pi / dk2_base
widths_shg = jnp.array([shg_width] * num_domains_shg)
widths_sfg = jnp.array([sfg_width] * num_domains_sfg)
widths = jnp.concatenate([widths_shg, widths_sfg])
kappas = kappa_mag * (-1) ** jnp.arange(widths.shape[0])

batched_calc_a3 = jit(vmap(cwes2.calc_a3_npda, [None, None, None, None, 0, 0]))
effs = jnp.abs(batched_calc_a3(a0_b0[0], kappas, kappas, widths, dk1s, dk2s)) ** 2 * NORO_FACTOR

simulate_twm_precise = jit(vmap(cwes2.simulate_twm, [None, None, None, 0, 0, None]))
amps = simulate_twm_precise(widths, kappas, kappas, dk1s, dk2s, a0_b0)
effs_precise = jnp.abs(amps[:, 2]) ** 2 * NORO_FACTOR

W0104 01:24:10.382413  475287 cuda_executor.cc:1802] GPU interconnect information not available: INTERNAL: NVML doesn't support extracting fabric info or NVLink is not used by the device.
W0104 01:24:10.387610  475185 cuda_executor.cc:1802] GPU interconnect information not available: INTERNAL: NVML doesn't support extracting fabric info or NVLink is not used by the device.
E0104 01:31:31.923258  475315 slow_operation_alarm.cc:73] 
********************************
[Compiling module jit_concatenate for GPU] Very slow compile? If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
********************************
E0104 01:31:31.939358  475185 slow_operation_alarm.cc:140] The operation took 7m19.541882963s

********************************
[Compiling module jit_concatenate for GPU] Very slow compile? If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
********************************


In [3]:
import plotly.graph_objects as go

fig = go.Figure()
fig.add_trace(go.Scatter(x=wls, y=effs, mode="lines", name="Efficiencies (NPDA)"))
fig.add_trace(go.Scatter(x=wls, y=effs_precise, mode="lines", name="Efficiencies (Precise)"))
fig.update_layout(
    title_text="Efficiency Spectrum",
    xaxis_title="Wavelength (μm)",
    yaxis_title="Efficiency (%/W)",
)
fig.show()