In [1]:
import jax
import jax.numpy as jnp
import plotly.graph_objects as go
from jax import jit, vmap

from qpm import cwes2, grating, mgoslt

# Batch `cwes2.simulate_twm` over the wavelength sweep, then JIT-compile it.
# Only the 4th and 5th positional arguments (the two phase-mismatch arrays)
# are mapped along axis 0. The grating widths, both kappa arrays and the
# initial amplitude vector are shared by every batch element.
_shg_in_axes = (None, None, None, 0, 0, None)
compute_shg_vectors = jit(vmap(cwes2.simulate_twm, in_axes=_shg_in_axes))


def plot_domain_widths(widths: jax.Array, title: str = "Domain Widths") -> None:
    """Plot the width of each grating domain against its index.

    Only a single width profile is drawn. (The previous docstring mentioned
    "before and after optimization", which this function never did.)

    Args:
        widths: Sequence of domain widths, one entry per domain, drawn on the
            y-axis. The axis is labelled in μm; this assumes the caller passes
            widths in μm (TODO confirm against `grating.build`).
        title: Figure title. Defaults to "Domain Widths", the original
            hard-coded title, so existing calls render unchanged.

    Returns:
        None. The figure is displayed via `fig.show()`.
    """
    fig = go.Figure()
    # No x values are given, so Plotly uses the array position (domain index).
    fig.add_trace(go.Scatter(y=widths, name="Domain width"))
    fig.update_layout(
        title_text=title,
        xaxis_title="Domain Index",
        yaxis_title="Width (μm)",
        # Same template as the spectrum plot in this notebook, for a uniform look.
        template="plotly_white",
    )
    fig.show()


# Run the chirped SHG simulation. The results are plotted in the next cell.

# --- Parameters ---
num_domains = 600
# Nonlinear coupling coefficient magnitude.
# NOTE(review): the division by 2/π presumably removes a first-order QPM factor
# already folded into 1.31e-5. Confirm the source of this value.
kappa_mag = 1.31e-5 / (2 / jnp.pi)
temperature = 70.0  # Operating temperature (°C)
# Fundamental wavelength sweep: start, end, number of samples (μm, per the plot axis).
wl_start, wl_end, num_points = 0.931, 1.066, 1000
# Scale factor applied to |b[1]|^2 when reporting efficiency, as num / den.
# NOTE(review): the origin of 1.07 / 2.84 is not documented here. Confirm it.
eff_scale_num, eff_scale_den = 1.07, 2.84

# --- Setup & Simulation ---
print("1. Building grating and preparing simulation inputs...")
chirp_rate = 0.0001  # Chirp rate passed to grating.tapered_profile (units not shown here)
initial_width = 3.6  # μm
profile = grating.tapered_profile(num_domains, initial_width, chirp_rate, kappa_mag)
widths, kappas = grating.build(profile)

wls = jnp.linspace(wl_start, wl_end, num_points)
# Phase mismatch at every sweep wavelength, one array per mixing process.
# delta_k1s uses (wls, wls), presumably FW + FW -> SH.
# delta_k2s uses (wls, wls / 2), presumably FW + SH -> TH.
# TODO confirm the argument semantics of mgoslt.calc_twm_delta_k.
delta_k1s = mgoslt.calc_twm_delta_k(wls, wls, temperature)
delta_k2s = mgoslt.calc_twm_delta_k(wls, wls / 2, temperature)
# Initial amplitude vector: unit amplitude in component 0 (presumably the
# fundamental) and zero in the other two components.
b_initial = jnp.array([1.0, 0.0, 0.0], dtype=jnp.complex64)

print("2. Running SHG simulation...")
# compute_shg_vectors is vmapped over the two delta_k arrays, so b_final holds
# one state vector per wavelength. The same kappa profile is used for both processes.
b_final = compute_shg_vectors(widths, kappas, kappas, delta_k1s, delta_k2s, b_initial)

# Efficiency in percent from |b[1]|^2 (component 1, presumably the SH amplitude),
# times the scale factor above. The operand order matches the original
# `* 100 * 1.07 / 2.84` so the floating-point result is unchanged.
effs = jnp.abs(b_final[:, 1]) ** 2 * 100 * eff_scale_num / eff_scale_den


W0104 01:28:18.597407  478254 cuda_executor.cc:1802] GPU interconnect information not available: INTERNAL: NVML doesn't support extracting fabric info or NVLink is not used by the device.
W0104 01:28:18.600041  478158 cuda_executor.cc:1802] GPU interconnect information not available: INTERNAL: NVML doesn't support extracting fabric info or NVLink is not used by the device.


1. Building grating and preparing simulation inputs...
2. Running SHG simulation...


In [2]:
# --- Plotting Results ---
print("3. Plotting results...")

# Spectrum figure: efficiency (effs) versus fundamental wavelength (wls).
spectrum_trace = go.Scatter(x=wls, y=effs, name="SHG Conversion Efficiency")
fig = go.Figure(data=[spectrum_trace])
fig.update_layout(
    template="plotly_white",
    title_text=f"Chirped SHG Spectrum ({num_domains} domains)",
    yaxis_title="SHG Conversion Efficiency (%/W)",
    xaxis_title="Fundamental Wavelength (μm)",
)
fig.show()

# Companion figure: the grating's domain-width profile.
plot_domain_widths(widths)

3. Plotting results...
