In [None]:
import jax
import jax.numpy as jnp
import plotly.graph_objects as go
from jax import jit, vmap

from qpm import cwes, grating, mgoslt

# Batched, JIT-compiled CWES solver. The grating arrays (widths, kappas) and the
# initial amplitude vector are shared by every point (in_axes=None), while the two
# phase-mismatch arrays are mapped along axis 0, i.e. one solve per wavelength.
compute_shg_spectrum = jit(
    vmap(
        cwes.simulate_twm,
        in_axes=(None, None, 0, 0, None),
    )
)


def calculate_verification_spectrum(
    wls: jax.Array,
    device_length: float,
    kappa_mag: float,
    design_temp: float,
    design_wl: float,
) -> jax.Array:
    """
    Analytical SHG spectrum in the non-depleted pump approximation (NPDA).

    Used to cross-check the CWES simulator: efficiency = kappa^2 * L^2 * sinc^2(dk*L/2).

    Args:
        wls: Fundamental wavelengths (μm) at which to evaluate the spectrum.
        device_length: Interaction length L (μm).
        kappa_mag: Coupling coefficient that enters the sinc^2 formula (μm⁻¹).
        design_temp: Crystal temperature (°C), used for the grating and for the sweep.
        design_wl: Wavelength (μm) at which the grating is exactly phase matched.

    Returns:
        Conversion efficiency in percent, one value per entry of ``wls``.
    """
    # The grating wavevector is whatever cancels the material mismatch at the design point.
    grating_k = mgoslt.calc_twm_delta_k(design_wl, design_wl, design_temp)

    # Remaining mismatch across the sweep once the grating contribution is removed.
    residual_dk = mgoslt.calc_twm_delta_k(wls, wls, design_temp) - grating_k

    # jnp.sinc is the normalized sinc, sin(pi*x)/(pi*x). Dividing the argument by
    # 2*pi therefore yields sin(dk*L/2)/(dk*L/2).
    phase_arg = residual_dk * device_length / (2 * jnp.pi)
    peak_efficiency = (kappa_mag**2) * (device_length**2)

    return peak_efficiency * jnp.sinc(phase_arg) ** 2 * 100


def main() -> None:
    """Main function to run the simulation, verification, and plot the comparison.

    Pipeline:
      1. Derive the QPM period from the SHG phase mismatch at the design point.
      2. Round the device to an integer number of half-period domains, build a
         uniform grating, and run the batched CWES simulator over the sweep.
      3. Evaluate the analytical sinc^2 (NPDA) spectrum for the same,
         discretized device length.
      4. Overlay both spectra in a Plotly figure (efficiencies in percent).
    """
    # --- Parameters ---
    device_length = 9400.0  # Target total length of the device (μm)
    # Effective first-order QPM coupling (μm⁻¹). In Noro's convention the quoted
    # kappa is this effective value, kappa_eff = (2/π)·kappa_mag. The analytical
    # NPDA model uses kappa_eff directly.
    kappa_eff = 1.5e-5
    kappa_mag = kappa_eff / (2 / jnp.pi)  # Nonlinear coupling coefficient magnitude (μm⁻¹)
    design_temp = 70.0  # QPM design temperature (°C)
    design_wl = 1.031  # QPM design wavelength (μm)
    num_points = 1000  # Resolution of the spectrum
    wl_start, wl_end = 1.025, 1.035  # Wavelength range for simulation (μm)

    # --- Setup ---
    print("1. Calculating QPM period from design parameters...")
    shg_delta_k_design = mgoslt.calc_twm_delta_k(design_wl, design_wl, design_temp)
    # First-order QPM: the grating wavevector 2π/Λ cancels the design-point mismatch.
    shg_qpm_period = 2 * jnp.pi / shg_delta_k_design
    print(f"    - Calculated QPM Period: {shg_qpm_period:.4f} μm")

    # Each domain is half a period long. Round to an integer domain count, then
    # report the resulting (slightly different) device length used everywhere below.
    num_domains = int(jnp.round(2 * device_length / shg_qpm_period))
    actual_device_length = num_domains * shg_qpm_period / 2
    print(f"    - Number of Domains: {num_domains}")
    print(f"    - Actual Device Length: {actual_device_length:.2f} μm")

    wls = jnp.linspace(wl_start, wl_end, num_points)

    # --- Simulator Calculation ---
    print("\n2. Building uniform grating and running CWES simulator...")
    # The grating is built from the bulk kappa_mag, not kappa_eff.
    profile = grating.uniform_profile(num_domains, shg_qpm_period, kappa_mag)
    widths, kappas = grating.build(profile)
    delta_k1s = mgoslt.calc_twm_delta_k(wls, wls, design_temp)
    delta_k2s = mgoslt.calc_twm_delta_k(wls, wls / 2, design_temp)
    # Unit amplitude at index 0, zero elsewhere. The field ordering is presumably
    # [fundamental, SH, TH]; confirm against cwes.simulate_twm.
    b_initial = jnp.array([1.0, 0.0, 0.0], dtype=jnp.complex64)
    simulator_amps = compute_shg_spectrum(widths, kappas, delta_k1s, delta_k2s, b_initial)
    # Column 1 is taken as the SH amplitude. With a unit pump, |A|^2 is the
    # conversion efficiency, scaled here to percent.
    simulator_effs = jnp.abs(simulator_amps[:, 1]) ** 2 * 100

    # --- Verification Calculation (NPDA) ---
    print("3. Calculating theoretical spectrum via NPDA (sinc^2)...")
    verification_effs = calculate_verification_spectrum(wls, actual_device_length, kappa_eff, design_temp, design_wl)

    # --- Plotting Results ---
    print("4. Plotting comparison...")

    fig = go.Figure()
    fig.add_trace(go.Scatter(x=wls, y=simulator_effs, mode="lines", name="Simulator Output", line={"color": "blue"}))
    fig.add_trace(go.Scatter(x=wls, y=verification_effs, mode="lines", name="Verification (NPDA)", line={"color": "red", "dash": "dash"}))
    fig.update_layout(
        xaxis_title="Fundamental Wavelength (μm)",
        # Both traces are multiplied by 100 above, so the axis is in percent.
        yaxis_title="Normalized SHG Conversion Efficiency (%)",
        title=f"Comparison of SHG Spectrum. Design λ: {design_wl} μm, Temp: {design_temp}°C",
    )
    fig.show()

In [6]:
# Run the full pipeline: QPM design, CWES simulation, NPDA verification, and the comparison plot.
main()

1. Calculating QPM period from design parameters...
    - Calculated QPM Period: 7.2017 μm
    - Number of Domains: 2611
    - Actual Device Length: 9401.77 μm

2. Building uniform grating and running CWES simulator...
3. Calculating theoretical spectrum via NPDA (sinc^2)...
4. Plotting comparison...
