In [19]:
import jax

# Enable JAX 64-bit mode (the `jax_enable_x64` flag) for this kernel session so
# float64/complex128 arrays are not silently downcast to 32-bit.
# NOTE(review): this flag is normally set before any JAX arrays are created,
# which is why it lives in the very first cell.
jax.config.update("jax_enable_x64", val=True)

In [20]:
import jax.numpy as jnp
import plotly.graph_objects as go
from jax import jit, vmap

from qpm import cwes2, grating, mgoslt

# Batch the single-point solver over its third argument only (the phase
# mismatch array); widths, kappas and the initial amplitude are shared by every
# batch element. The batched function is then compiled with jit.
_batched_shg = vmap(cwes2.simulate_shg_npda, in_axes=(None, None, 0, None))
compute_shg_vectors = jit(_batched_shg)


def plot_domain_widths(widths: jax.Array, title: str = "Domain Widths") -> None:
    """Plot each domain's width against its index as a single line trace.

    Args:
        widths: Domain widths in μm, one entry per domain (e.g. the ``widths``
            returned by ``grating.build``); presumably 1-D — plotted as the
            y values with the implicit 0..N-1 index on the x axis.
        title: Figure title. Defaults to ``"Domain Widths"`` so existing calls
            render exactly as before.
    """
    fig = go.Figure()
    fig.add_trace(go.Scatter(y=widths))
    fig.update_layout(
        title_text=title,
        xaxis_title="Domain Index",
        yaxis_title="Width (μm)",
    )
    fig.show()


# Run the chirped-grating SHG simulation plus a uniform-grating reference.

# --- Parameters ---
DEVICE_LENGTH_UM = 15000.0  # 15 mm grating length, in μm
POLING_PERIOD_UM = 7.2  # nominal poling period, in μm
# Two domains per poling period.
num_domains = round(DEVICE_LENGTH_UM / POLING_PERIOD_UM * 2)
kappa_mag = 1.5e-5  # Nonlinear coupling coefficient magnitude
temperature = 70.0  # Operating temperature (°C)
wl_start, wl_end, num_points = 0.931, 1.066, 1000  # wavelength sweep (μm)

# --- Setup & Simulation ---
print("1. Building grating and preparing simulation inputs...")
chirp_rate = 0.00001  # Chirp rate
initial_width = POLING_PERIOD_UM / 2  # μm; half the poling period (= 3.6)
profile = grating.tapered_profile(num_domains, initial_width, chirp_rate, kappa_mag)
widths, kappas = grating.build(profile)

wls = jnp.linspace(wl_start, wl_end, num_points)
delta_k1s = mgoslt.calc_twm_delta_k(wls, wls, temperature)
# Initial amplitude handed to the solver: sqrt(1.0) = 1.
# NOTE(review): x64 mode is enabled in the first cell, yet this is complex64 —
# confirm whether cwes2 promotes it or whether complex128 was intended.
b_initial = jnp.array(jnp.sqrt(1.0), dtype=jnp.complex64)

print("2. Running SHG simulation...")
# Chirped grating: one solve per entry of delta_k1s, batched by vmap.
b_final = compute_shg_vectors(widths, kappas, delta_k1s, b_initial)
# Efficiency in percent: |b|^2 scaled by 100.
effs = jnp.abs(b_final) ** 2 * 100

# Uniform-period reference with the same domain count and coupling magnitude.
# NOTE(review): tapered_profile above takes a domain *width* (3.6 μm) while this
# call passes the full period (7.2 μm) — confirm uniform_profile expects a period.
p2 = grating.uniform_profile(num_domains, POLING_PERIOD_UM, kappa_mag)
ws, ks = grating.build(p2)
b_fin = compute_shg_vectors(ws, ks, delta_k1s, b_initial)
efs = jnp.abs(b_fin) ** 2 * 100

1. Building grating and preparing simulation inputs...
2. Running SHG simulation...


In [21]:
# --- Plotting Results ---
# Overlay the chirped-grating spectrum (effs) on the uniform-grating reference
# (efs). Previously only efs was plotted under a "Chirped" title, and the axis
# claimed watts although both arrays are |b|^2 * 100, i.e. percent.
print("3. Plotting results...")
fig = go.Figure()
fig.add_trace(
    go.Scatter(
        x=wls,
        y=effs,
        name="Chirped grating",
    ),
)
fig.add_trace(
    go.Scatter(
        x=wls,
        y=efs,
        name="Uniform grating (reference)",
    ),
)
fig.update_layout(
    title_text=f"Chirped vs. Uniform SHG Spectrum ({num_domains} domains)",
    xaxis_title="Fundamental Wavelength (μm)",
    yaxis_title="SHG Conversion Efficiency (%)",
    template="plotly_white",
)
fig.show()

plot_domain_widths(widths)

3. Plotting results...
