In [5]:
import jax

# Switch on JAX's "jax_enable_x64" flag so the arrays built in the cells below
# can use 64-bit float/complex types.
# NOTE(review): JAX's docs say this flag should be set at startup, before any
# arrays are created — keep this cell ahead of every other JAX call.
jax.config.update("jax_enable_x64", val=True)

In [None]:
import jax.numpy as jnp

from qpm import mgoslt


def calculate_local_shg_amplitudes(
    domain_widths: jax.Array,
    kappa_vals: jax.Array,
    delta_k: jax.Array,
    b_initial: jax.Array,
) -> jax.Array:
    """Return the SHG amplitude each domain contributes, evaluated as if it started at z = 0.

    For a domain of width ``L`` with coupling ``kappa`` the value is::

        -i * kappa * b_initial**2 * L * exp(i * delta_k * L / 2) * sin(delta_k * L / 2) / (delta_k * L / 2)

    which equals ``-i * kappa * b_initial**2`` times the integral of
    ``exp(i * delta_k * z)`` over ``0 <= z <= L``.

    Args:
        domain_widths: Width of every domain.
        kappa_vals: Coupling coefficient of every domain; multiplied element-wise
            with ``domain_widths``.
        delta_k: Phase mismatch.
        b_initial: Fundamental-wave amplitude; it enters squared.

    Returns:
        Complex per-domain amplitudes, without the phase offset of each domain's
        position in the structure.
    """
    half_mismatch = delta_k / 2.0
    pump_squared = b_initial**2
    half_phase = half_mismatch * domain_widths

    # jnp.sinc is the normalized sinc, sin(pi * x) / (pi * x); dividing the argument
    # by pi therefore yields sin(half_phase) / half_phase.
    envelope = jnp.sinc(half_phase / jnp.pi)
    carrier = jnp.exp(1j * half_phase)

    prefactor = -1j * kappa_vals * pump_squared * domain_widths
    return prefactor * carrier * envelope


def simulate_shg_npda(
    domain_widths: jax.Array,
    kappa_vals: jax.Array,
    delta_k: jax.Array,
    b_initial: jax.Array,
) -> jax.Array:
    """Add up all per-domain SHG contributions into one output amplitude.

    Each domain's local amplitude (from ``calculate_local_shg_amplitudes``) is
    multiplied by ``exp(i * delta_k * z_start)``, where ``z_start`` is the summed
    width of all preceding domains, and the products are summed.

    Args:
        domain_widths: Width of every domain, in propagation order.
        kappa_vals: Coupling coefficient of every domain.
        delta_k: Phase mismatch.
        b_initial: Fundamental-wave amplitude.

    Returns:
        The summed complex SHG amplitude (a 0-d array).
    """
    # Start position of each domain: 0 for the first, then the running total of
    # the widths that come before it.
    widths_before = jnp.cumsum(domain_widths[:-1])
    domain_starts = jnp.concatenate([jnp.array([0.0]), widths_before])

    per_domain = calculate_local_shg_amplitudes(domain_widths, kappa_vals, delta_k, b_initial)
    positioned = per_domain * jnp.exp(1j * delta_k * domain_starts)
    return jnp.sum(positioned)


def simulate_shg_npda_trace(
    domain_widths: jax.Array,
    kappa_vals: jax.Array,
    delta_k: jax.Array,
    b_initial: jax.Array,
) -> tuple[jax.Array, jax.Array]:
    """Compute the cumulative SHG amplitude at every domain boundary.

    Uses the same per-domain terms as ``simulate_shg_npda`` but keeps the running
    sum instead of only the total.

    Args:
        domain_widths: Width of every domain, in propagation order.
        kappa_vals: Coupling coefficient of every domain.
        delta_k: Phase mismatch.
        b_initial: Fundamental-wave amplitude.

    Returns:
        ``(z_coords, shg_amplitude_trace)``. ``z_coords`` holds the boundary
        positions, starting at 0. ``shg_amplitude_trace`` holds the complex
        amplitude accumulated up to each boundary, starting at 0. Both have one
        more entry than ``domain_widths``.
    """
    origin = jnp.array([0.0])
    # Left edge of each domain (phase reference) and the full list of boundaries.
    domain_starts = jnp.concatenate([origin, jnp.cumsum(domain_widths[:-1])])
    boundaries = jnp.concatenate([origin, jnp.cumsum(domain_widths)])

    per_domain = calculate_local_shg_amplitudes(domain_widths, kappa_vals, delta_k, b_initial)
    positioned = per_domain * jnp.exp(1j * delta_k * domain_starts)

    # Prepend zero so the trace lines up with `boundaries`: nothing is generated at z = 0.
    running_total = jnp.cumsum(positioned)
    amplitude_at_boundaries = jnp.concatenate([jnp.array([0.0j]), running_total])

    return boundaries, amplitude_at_boundaries

In [7]:
# --- Design point ---
# NOTE(review): units are not stated here; wavelength is presumably in μm (to match
# the plot's z axis) and temperature presumably in °C — confirm against qpm.mgoslt.
design_wl = 1.031
design_temp = 70.0
b_initial = jnp.array(1.0 + 0.0j)

# --- Domain structure ---
# Widths are read from a saved array, and the domain count follows from that file.
# Alternative (not used): a uniform structure with every domain of width
# jnp.pi / delta_k1_design.
optimized_widths = jnp.load("../datasets/optimized_thg_2000_1489_e4.npy")
domain_widths = optimized_widths
num_domains = optimized_widths.shape[0]

# --- Coupling coefficients ---
# Same magnitude in every domain with the sign alternating +, -, +, ...
kappa_mag = 1.31e-5 / (2 / jnp.pi)
kappa_vals = kappa_mag * ((-1) ** jnp.arange(num_domains))

# --- Phase mismatch at the design point and the resulting SHG trace ---
delta_k1_design = mgoslt.calc_twm_delta_k(design_wl, design_wl, design_temp)

z_coords, full_trace = simulate_shg_npda_trace(
    domain_widths=domain_widths,
    kappa_vals=kappa_vals,
    delta_k=delta_k1_design,
    b_initial=b_initial,
)

In [8]:
import plotly.graph_objects as go

# jnp.abs of the complex trace is the amplitude magnitude |A₂ω|, not an intensity
# (which would scale as |A₂ω|²), so the trace and y axis are labelled as amplitude.
shg_amplitude = jnp.abs(full_trace)
# Previous name for the same array, kept so any later cell that reads it still works.
shg_intensity = shg_amplitude

# Create and display the figure with Plotly
fig = go.Figure(data=go.Scatter(x=z_coords, y=shg_amplitude, mode="lines", name="SHG Amplitude"))

fig.update_layout(
    title="Second-Harmonic Generation Growth (NPDA Simulation)",
    xaxis_title="Propagation Distance, z (μm)",
    yaxis_title="SHG Amplitude, |A₂ω|",
)

fig.show()