In [3]:
import jax.numpy as jnp
import plotly.graph_objects as go  # pyright: ignore[reportMissingTypeStubs]
from jax import jit, vmap

from qpm import cwes, grating, mgoslt

# Batched, JIT-compiled wrapper around cwes.simulate_twm.
# in_axes maps over axis 0 of the 2nd and 3rd positional arguments (the two
# phase-mismatch arrays, as passed by main) and broadcasts the 1st (grating)
# and 4th (initial amplitude vector) unchanged to every batch element.
compute_shg_vectors = jit(
    vmap(
        cwes.simulate_twm,
        in_axes=(
            None,  # grating / superlattice: shared by every wavelength
            0,  # first phase-mismatch array: one entry per wavelength
            0,  # second phase-mismatch array: one entry per wavelength
            None,  # initial amplitude vector: shared by every wavelength
        ),
        out_axes=0,  # stack per-wavelength results along a new leading axis (vmap default)
    )
)


def main(
    *,
    num_domains: int = 600,
    initial_width: float = 3.6,
    chirp_rate: float = 0.0001,
    kappa_mag: float = 1.31e-5 / (2 / jnp.pi),
    temperature: float = 70.0,
    wl_start: float = 0.931,
    wl_end: float = 1.066,
    num_points: int = 1000,
) -> None:
    """Build a chirped grating, simulate its SHG spectrum, and plot the result.

    Every parameter is keyword-only and defaults to the value that used to be
    hard-coded in this function, so ``main()`` behaves exactly as before.
    Sweeps can override individual values, e.g. ``main(chirp_rate=2e-4)``.

    Args:
        num_domains: Number of domains passed to ``grating.tapered_profile``.
        initial_width: Initial domain width, in μm.
        chirp_rate: Chirp rate forwarded to ``grating.tapered_profile``.
            Its units and definition are set by that helper (TODO confirm).
        kappa_mag: Nonlinear coupling coefficient magnitude. The default
            divides 1.31e-5 by 2/π. This presumably undoes the first-order
            QPM 2/π factor; confirm against the source of 1.31e-5.
        temperature: Operating temperature in °C, used for the
            phase-mismatch calculation.
        wl_start: First fundamental wavelength of the sweep, in μm.
        wl_end: Last fundamental wavelength of the sweep (inclusive), in μm.
        num_points: Number of wavelengths sampled between ``wl_start`` and
            ``wl_end``.
    """
    # --- Setup & Simulation ---
    print("1. Building grating and preparing simulation inputs...")
    profile = grating.tapered_profile(num_domains, initial_width, chirp_rate, kappa_mag)
    superlattice = grating.build(profile)
    grating.visualize(superlattice)

    # Fundamental wavelength sweep (μm).
    wls = jnp.linspace(wl_start, wl_end, num_points)
    # Phase mismatch of the fundamental + fundamental process (SHG).
    delta_k1s = mgoslt.calc_twm_delta_k(wls, wls, temperature)
    # Phase mismatch of fundamental + wls / 2 (the SH wavelength). This is
    # presumably the cascaded sum-frequency step modelled by simulate_twm;
    # confirm against mgoslt / cwes.
    delta_k2s = mgoslt.calc_twm_delta_k(wls, wls / 2, temperature)
    # Initial complex amplitudes. Only index 0 (the fundamental) is non-zero,
    # with unit amplitude. Index 1 is read back below as the SH.
    b_initial = jnp.array([1.0, 0.0, 0.0], dtype=jnp.complex64)

    print("2. Running SHG simulation...")
    # compute_shg_vectors maps over the wavelength axis of both mismatch
    # arrays and shares the grating and initial state. It returns one final
    # amplitude vector per wavelength, stacked along axis 0.
    b_final_vectors = compute_shg_vectors(superlattice, delta_k1s, delta_k2s, b_initial)

    # |b_SH|^2 per wavelength. The fundamental input has unit amplitude, so
    # this is the SH output per unit input. Reading it as %/W assumes |b|^2
    # is in watts (TODO confirm the normalization used by cwes).
    efficiencies = jnp.abs(b_final_vectors[:, 1]) ** 2

    # --- Plotting Results ---
    print("3. Plotting results...")
    fig = go.Figure()
    fig.add_trace(
        go.Scatter(
            x=wls,
            y=efficiencies * 100,
            mode="lines",
            name="SHG Conversion Efficiency",
            line={"color": "blue", "width": 2},
        ),
    )
    fig.update_layout(
        title_text=f"Chirped SHG Spectrum ({num_domains} domains)",
        xaxis_title="Fundamental Wavelength (μm)",
        yaxis_title="SHG Conversion Efficiency (%/W)",
        template="plotly_white",
    )
    fig.show()

In [4]:
# Run the whole pipeline once with main()'s built-in parameter values:
# grating build, SHG simulation, and plots. main() returns None, so the cell
# shows only what main prints and displays.
main()

1. Building grating and preparing simulation inputs...


2. Running SHG simulation...
3. Plotting results...
