In [6]:
import jax

# Turn on 64-bit dtypes in JAX before any arrays are built: later cells request
# complex128 state vectors, which JAX would otherwise silently downcast to complex64.
jax.config.update("jax_enable_x64", True)

In [8]:
import jax
import jax.numpy as jnp
from jax import jit, vmap

from qpm import cwes, mgoslt


def build_simulation_inputs(
    total_length: float,
    shg_domain_width: float,
    sfg_domain_width: float,
    max_total_domains: int,
    kappa_mag: float,
) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array]:
    """Generate batched inputs for every valid SHG/SFG domain configuration.

    Each configuration places ``n_shg`` domains of width ``shg_domain_width`` first,
    followed by as many domains of width ``sfg_domain_width`` as still fit inside
    ``total_length``. ``n_shg`` ranges over ``0 .. floor(total_length / shg_domain_width)``.
    Configurations with zero domains in total are dropped.

    Every configuration is padded to ``max_total_domains`` slots. Padded slots have
    width 0 and kappa 0. Active domains alternate in sign
    (``+kappa_mag, -kappa_mag, ...``) starting at index 0.

    Args:
        total_length: Maximum device length (same units as the domain widths).
        shg_domain_width: Width of each SHG-section domain.
        sfg_domain_width: Width of each SFG-section domain.
        max_total_domains: Number of slots per configuration. It must be at least
            the largest total domain count of any configuration.
        kappa_mag: Magnitude of the per-domain coupling coefficient.

    Returns:
        ``(batched_widths, batched_kappa_vals, shg_counts, sfg_counts)``:
        - ``batched_widths`` and ``batched_kappa_vals`` have shape
          ``(num_configs, max_total_domains)``.
        - ``shg_counts`` and ``sfg_counts`` have shape ``(num_configs,)``.

    Raises:
        ValueError: If ``max_total_domains`` is smaller than the domain count of some
            configuration. Those configurations would otherwise be silently truncated
            while their reported counts stayed unchanged.
    """
    # 1. Enumerate every SHG domain count that fits, then fill the remainder with SFG domains.
    shg_counts = jnp.arange(jnp.floor(total_length / shg_domain_width).astype(jnp.int32) + 1)
    sfg_counts = jnp.floor((total_length - shg_counts * shg_domain_width) / sfg_domain_width).astype(jnp.int32)
    total_counts = shg_counts + sfg_counts

    # 2. Keep only configurations with a non-negative SFG count and at least one domain.
    #    (Boolean-mask indexing needs concrete values, so this function is not jit-able.)
    valid_mask = (sfg_counts >= 0) & (total_counts > 0)
    shg_counts, sfg_counts, total_counts = shg_counts[valid_mask], sfg_counts[valid_mask], total_counts[valid_mask]

    # Domains past max_total_domains would be masked off below without any warning,
    # so the simulated device would not match shg_counts/sfg_counts.
    required_domains = int(jnp.max(total_counts)) if total_counts.size > 0 else 0
    if required_domains > max_total_domains:
        raise ValueError(
            f"max_total_domains={max_total_domains} is too small; "
            f"the longest configuration needs {required_domains} domains"
        )

    # 3. Build padded per-domain arrays with broadcast masks
    #    (configurations along axis 0, domain slots along axis 1).
    domain_indices = jnp.arange(max_total_domains)
    is_shg_domain = domain_indices < shg_counts[:, None]
    is_active_domain = domain_indices < total_counts[:, None]
    is_sfg_domain = is_active_domain & ~is_shg_domain

    batched_widths = jnp.where(is_shg_domain, shg_domain_width, jnp.where(is_sfg_domain, sfg_domain_width, 0.0))

    # Alternating sign per domain index: +kappa_mag on even slots, -kappa_mag on odd ones.
    base_kappas = kappa_mag * jnp.power(-1, domain_indices)
    batched_kappa_vals = jnp.where(is_active_domain, base_kappas, 0.0)

    return batched_widths, batched_kappa_vals, shg_counts, sfg_counts


# --- Simulation Parameters ---
total_length = 2e3  # µm
design_wl = 1.031  # presumably µm (same units as total_length) — TODO confirm
design_temp = 70.0  # passed to mgoslt.calc_twm_delta_k; presumably °C — TODO confirm
# NOTE(review): 2/π is the first-order Fourier coefficient of a 50%-duty square wave, so dividing
# by it presumably converts an effective QPM coupling into a per-domain coupling — confirm with cwes.
kappa_mag = 1.31e-4 / (2 / jnp.pi)
# Initial mode amplitudes. Index 2 is read back as the THG output below
# (presumably [fundamental, SH, TH] — TODO confirm against cwes.simulate_twm).
b_initial = jnp.array([1.0, 0.0, 0.0], dtype=jnp.complex128)

# --- Core Logic ---
# Phase mismatches of the two cascaded processes at the design point.
delta_k1 = mgoslt.calc_twm_delta_k(design_wl, design_wl, design_temp)
delta_k2 = mgoslt.calc_twm_delta_k(design_wl, design_wl / 2, design_temp)
# Domain width π/Δk for each process.
shg_domain_width = jnp.pi / delta_k1
sfg_domain_width = jnp.pi / delta_k2

# 1. Generate all inputs in a single vectorized operation.
# No configuration can hold more than floor(total_length / narrowest width) domains.
# Bound by the narrower width so build_simulation_inputs never has to truncate,
# whichever process happens to have the shorter domains.
narrowest_domain_width = jnp.minimum(shg_domain_width, sfg_domain_width)
max_total_domains = int(total_length / narrowest_domain_width)
batched_widths, batched_kappa_vals, shg_counts, sfg_counts = build_simulation_inputs(
    total_length, shg_domain_width, sfg_domain_width, max_total_domains, kappa_mag
)

# 2. Run batched simulations. Widths and kappas vary per configuration (axis 0);
#    the phase mismatches and initial state are shared across the batch.
batched_simulate = jit(vmap(cwes.simulate_twm, in_axes=(0, 0, None, None, None)))
b_finals = batched_simulate(batched_widths, batched_kappa_vals, delta_k1, delta_k2, b_initial)
thg_powers = jnp.abs(b_finals[:, 2]) ** 2
thg_powers.block_until_ready()

# 3. Analyze and display the final results
best_idx = jnp.argmax(thg_powers)
max_power = thg_powers[best_idx]
optimal_num_shg = shg_counts[best_idx]
optimal_num_sfg = sfg_counts[best_idx]

optimal_shg_length = optimal_num_shg * shg_domain_width
optimal_sfg_length = optimal_num_sfg * sfg_domain_width
# Whole domains rarely tile total_length exactly, so the realized length is slightly shorter.
actual_total_length = optimal_shg_length + optimal_sfg_length
optimal_ratio = optimal_shg_length / actual_total_length

print(f"Total Device Length: {total_length:.2f} µm")
print("---")
print(f"Optimal SHG/Total Length Ratio: {optimal_ratio:.4f}")
print(f"Optimal SHG Domains: {optimal_num_shg}")
print(f"Optimal SFG Domains: {optimal_num_sfg}")
print(f"Actual Device Length in Optimal Config: {actual_total_length:.2f} µm")
print(f"Max Power: {max_power:.6f}")

Total Device Length: 2000.00 µm
---
Optimal SHG/Total Length Ratio: 0.5024
Optimal SHG Domains: 279
Optimal SFG Domains: 1016
Actual Device Length in Optimal Config: 1999.86 µm
Max Power: 0.002490


In [10]:
import jax.numpy as jnp
import plotly.graph_objects as go

# Batched over wavelength: widths, kappas and the initial state are shared (None);
# the two phase-mismatch arrays are mapped along axis 0.
batch_simulate_twm = jit(vmap(cwes.simulate_twm, in_axes=(None, None, 0, 0, None)))

design_wl = 1.031
design_temp = 70.0
# NOTE(review): 10x smaller than kappa_mag in the domain-count sweep cell (1.31e-4).
# Confirm which is intended; the reported figure of merit is normalized by device_length**2.
kappa_mag = 1.31e-5 / (2 / jnp.pi)

# Design-point phase mismatches (scalars). The spectra below use delta_k1s/delta_k2s,
# so these scalars are never overwritten with spectrum arrays.
delta_k1 = mgoslt.calc_twm_delta_k(design_wl, design_wl, design_temp)
delta_k2 = mgoslt.calc_twm_delta_k(design_wl, design_wl / 2, design_temp)

# Wavelength sweep shared by both spectra.
wls = jnp.linspace(1.025, 1.035, 1000)
delta_k1s = mgoslt.calc_twm_delta_k(wls, wls, design_temp)
delta_k2s = mgoslt.calc_twm_delta_k(wls, wls / 2, design_temp)

b_initial = jnp.array([1.0, 0.0, 0.0], dtype=jnp.complex128)


def _thg_power(b_final_batch: jax.Array) -> jax.Array:
    """Return the scaled squared magnitude of amplitude index 2 for each batch row.

    NOTE(review): the factor ``100 * 1.07 / 2.84`` is undocumented here. It is presumably
    a unit/normalization conversion — document its origin. The evaluation order is kept
    as-is so results match bit-for-bit.
    """
    return jnp.abs(b_final_batch[:, 2]) ** 2 * 100 * 1.07 / 2.84


# --- Reference: periodic SHG-then-SFG structure chosen by the domain-count sweep ---
widths_shg = jnp.full(int(optimal_num_shg), shg_domain_width)
widths_sfg = jnp.full(int(optimal_num_sfg), sfg_domain_width)
widths_thg = jnp.concatenate([widths_shg, widths_sfg])
kappa_vals = kappa_mag * jnp.power(-1, jnp.arange(widths_thg.shape[0]))

b_final = batch_simulate_twm(widths_thg, kappa_vals, delta_k1s, delta_k2s, b_initial)
thw_powers_periodical = _thg_power(b_final)

device_length = jnp.sum(widths_thg)
print(optimal_num_shg, optimal_num_sfg)
print(f"Device length: {device_length}")
print(f"Max power / (Device length)^2: {jnp.max(thw_powers_periodical) / device_length**2}")

# --- Optimized structure loaded from disk ---
widths_raw = jnp.load("../datasets/optimized_thg_2000_1489_e4.npy")
# The loaded widths contained exactly one domain with a negative width. That domain's width
# is replaced with 100 (same units as the other widths), a value hand-tuned so the spectral
# peak lands at 1.031.
widths = jnp.where(widths_raw < 0, 100, widths_raw)
num_domains = widths.shape[0]
kappa_vals = kappa_mag * jnp.power(-1, jnp.arange(num_domains))

# Full three-amplitude end state of the optimized structure, one row per wavelength.
b_end_optimized = batch_simulate_twm(widths, kappa_vals, delta_k1s, delta_k2s, b_initial)
thw_powers = _thg_power(b_end_optimized)

fig = go.Figure()
fig.add_trace(go.Scatter(x=wls, y=thw_powers, mode="lines", name="Optimized Structure"))
fig.add_trace(go.Scatter(x=wls, y=thw_powers_periodical, mode="lines", name="Periodic Structure"))
fig.update_layout(
    title_text="Power Spectrum",
    xaxis_title="Wavelength (μm)",
    yaxis_title="Power",
)
fig.show()

device_length = jnp.sum(widths)
print(f"Device length: {device_length}")
print(f"Max power / (Device length)^2: {jnp.max(thw_powers) / device_length**2}")

279 1016
Device length: 1999.8648792846866
Max power / (Device length)^2: 2.493852083995929e-12


Device length: 2299.97314453125
Max power / (Device length)^2: 4.30323275238301e-12


In [37]:
def plot_domain_widths(optimized_widths: jax.Array, periodic_widths: jax.Array) -> None:
    """Display a Plotly line chart comparing per-domain widths of two structures.

    The x-axis is the domain index and the y-axis the domain width. The title
    reports the sum of ``optimized_widths`` to one decimal place.

    Args:
        optimized_widths: Widths of the optimized structure, one entry per domain.
        periodic_widths: Widths of the periodic reference structure, drawn dotted.

    Returns:
        None. The figure is rendered with ``show()``.
    """
    optimized_length = jnp.sum(optimized_widths)

    periodic_trace = go.Scatter(y=periodic_widths, mode="lines", name="Periodic", line={"dash": "dot"})
    optimized_trace = go.Scatter(y=optimized_widths, mode="lines", name="Optimized")

    figure = go.Figure()
    # Periodic first, so the optimized curve is drawn on top of the dotted reference.
    for trace in (periodic_trace, optimized_trace):
        figure.add_trace(trace)

    figure.update_layout(
        title_text=f"Domain Widths Comparison (length {optimized_length:.1f} µm)",
        xaxis_title="Domain Index",
        yaxis_title="Width (μm)",
        template="plotly_white",
    )
    figure.show()


# Compare the loaded optimized structure against the periodic reference built earlier.
plot_domain_widths(optimized_widths=widths, periodic_widths=widths_thg)