In [1]:
import jax

# Switch JAX to 64-bit mode so new arrays default to float64/complex128
# instead of float32/complex64. Per the JAX docs this should be set at
# startup, before any arrays are created.
jax.config.update("jax_enable_x64", True)

In [2]:
import jax.numpy as jnp
from plotly import graph_objects as go

from qpm import cwes, mgoslt

# Batch the NPDA SHG simulator over its third argument (dk) only; widths,
# kappas and b_initial are shared by every point of the sweep.
batch_simulate = jax.jit(
    jax.vmap(cwes.simulate_shg_npda, in_axes=(None, None, 0, None))
)

# --- Parameters ---------------------------------------------------------------
NORO_CRR_FACTOR = 1.07 / 2.84 * 100  # scale applied to |amp|^2; origin not documented here - TODO confirm
design_temp = 70.0
kappa_mag = 1.31e-5 / (2 / jnp.pi)  # presumably undoes the 2/pi first-order QPM factor - TODO confirm
b_initial = jnp.array(1.0 + 0.0j)

# Wavelength sweep (500 points) and the corresponding SHG phase mismatch.
wl_start, wl_end = 1.025, 1.035
wls = jnp.linspace(wl_start, wl_end, 500)
dks = mgoslt.calc_twm_delta_k(wls, wls, design_temp)

# --- Uniform grating designed for design_wl -----------------------------------
num_domains = 5555
design_wl = 1.031
# Each domain is pi / dk(design_wl) long.
shg_width = jnp.pi / mgoslt.calc_twm_delta_k(design_wl, design_wl, design_temp)
widths = jnp.full(num_domains, shg_width)

# Coupling alternates sign from one domain to the next: domain i gets (-1)**i.
domain_signs = (-1) ** jnp.arange(num_domains)
kappas = kappa_mag * domain_signs

# --- Simulate and plot --------------------------------------------------------
amps = batch_simulate(widths, kappas, dks, b_initial)
effs = NORO_CRR_FACTOR * jnp.abs(amps) ** 2

fig = go.Figure(go.Scatter(x=wls, y=effs, mode="lines"))
fig.update_layout(
    title="SHG Conversion Efficiency vs Wavelength",
    xaxis_title="Wavelength (microns)",
    yaxis_title="SHG Conversion Efficiency",
)
fig.show()

In [3]:
import jax.numpy as jnp
from plotly import graph_objects as go

from qpm import cwes, mgoslt

# Batch the NPDA SHG simulator over its third argument (dk) only; widths,
# kappas and b_initial are shared by every point of the sweep.
batch_simulate = jax.jit(jax.vmap(cwes.simulate_shg_npda, in_axes=(None, None, 0, None)))

NORO_CRR_FACTOR = 1.07 / 2.84 * 100  # scale applied to |amp|^2; origin not documented here - TODO confirm
design_temp = 70.0
kappa_mag = 1.31e-5 / (2 / jnp.pi)  # presumably undoes the 2/pi first-order QPM factor - TODO confirm
b_initial = jnp.array(1.0 + 0.0j)
wl_start = 1.025
wl_end = 1.035
wls = jnp.linspace(wl_start, wl_end, 500)
num_domains = 5555
design_wl = 1.031

design_dk = mgoslt.calc_twm_delta_k(design_wl, design_wl, design_temp)
dks = mgoslt.calc_twm_delta_k(wls, wls, design_temp)

# Phase mismatch relative to the design wavelength (zero at design_wl).
delta_k_scan = dks - design_dk

# Every domain is pi / design_dk long.
shg_width = jnp.pi / design_dk
widths = jnp.full((num_domains,), shg_width)

n = jnp.arange(num_domains)
center = (num_domains - 1) / 2.0

# Domain index measured from the middle of the grating.
z_n = n - center

# 1. Base sinc profile (sets the bandwidth).
# jnp.sinc is the normalized sinc (zeros at integer arguments), and sinc_arg
# spans [-apodization_scale_factor, +apodization_scale_factor] across the
# grating, so the factor sets how many zero crossings appear on each side.
apodization_scale_factor = 20  # tuning parameter
sinc_arg = z_n * apodization_scale_factor / center
base_sinc = jnp.sinc(sinc_arg)

# 2. Gaussian window (suppresses the ripple caused by truncating the sinc).
# Changing sigma trades flat-top "flatness" against edge steepness.
sigma = center * 0.5  # tuning parameter
gaussian_window = jnp.exp(-(z_n**2) / (2 * sigma**2))

# 3. Apply the window to the sinc to get the apodization profile.
apodization = base_sinc * gaussian_window

# Overall gain on the apodized coupling (tuning value, previously an inline 30).
kappa_gain = 30
kappas_apodized = kappa_mag * ((-1) ** n) * apodization * kappa_gain

amps_sinc = batch_simulate(widths, kappas_apodized, dks, b_initial)
effs_sinc = jnp.abs(amps_sinc) ** 2 * NORO_CRR_FACTOR

# Reference: uniform (unapodized) grating with alternating-sign coupling.
kappas_orig = kappa_mag * ((-1) ** n)
amps_orig = batch_simulate(widths, kappas_orig, dks, b_initial)
effs_orig = jnp.abs(amps_orig) ** 2 * NORO_CRR_FACTOR

fig = go.Figure()
fig.add_trace(go.Scatter(x=delta_k_scan, y=effs_sinc, mode="lines", name="Sinc"))
fig.add_trace(go.Scatter(x=delta_k_scan, y=effs_orig, mode="lines", name="Const"))
fig.update_layout(
    title="SHG Conversion Efficiency (Sinc vs const)",
    # Plotly only typesets LaTeX when the entire label is wrapped in $...$;
    # a mixed text/math label would display the raw "$\Delta k$" markup.
    xaxis_title="Phase mismatch Δk - Δk_design (dks - design_dk)",
    yaxis_title="SHG Conversion Efficiency (Log)",
    # The axis title advertises a log scale, so make the axis logarithmic.
    yaxis_type="log",
    legend_title=r"$\kappa (z)$",
)
fig.show()