In [2]:
import jax

jax.config.update("jax_platforms", "cpu")

In [4]:
from collections.abc import Callable

import jax.numpy as jnp
from jax import jit, vmap

from qpm import cwes, mgoslt


def create_thg_power_evaluator(
    num_domains: int,
    design_wl: float,
    design_temp: float,
    kappa_mag: float,
    b_initial: jax.Array,
) -> Callable[[jax.Array], jax.Array]:
    """
    Creates a JIT-compiled and vectorized function to evaluate THG power.

    The returned function takes an array of `num_domains_shg` and efficiently
    computes the THG power for each case at the specified design wavelength.
    """
    # 1. Pre-calculate all constants outside the mapped function
    delta_k1 = mgoslt.calc_twm_delta_k(design_wl, design_wl, design_temp)
    delta_k2 = mgoslt.calc_twm_delta_k(design_wl, design_wl / 2, design_temp)
    shg_domain_width = jnp.pi / delta_k1
    sfg_domain_width = jnp.pi / delta_k2

    # Kappa values depend only on the total number of domains, so pre-calculate
    kappa_vals = kappa_mag * jnp.power(-1, jnp.arange(num_domains))
    domain_indices = jnp.arange(num_domains)

    # 2. Define the core function to be vectorized
    def _calculate_power_at_split(num_domains_shg: int) -> jax.Array:
        """Calculates THG power for a single SHG/SFG split."""
        # Use `jnp.where` for JIT-compatibility
        widths = jnp.where(domain_indices < num_domains_shg, shg_domain_width, sfg_domain_width)
        # Simulate for the single design wavelength
        b_final = cwes.simulate_twm(widths, kappa_vals, delta_k1, delta_k2, b_initial)
        return jnp.abs(b_final[2]) ** 2

    # 3. Return the JIT-compiled, vmapped function
    return jit(vmap(_calculate_power_at_split))


# --- Example Usage ---

# Define simulation parameters
num_domains = 1000
design_wl = 1.031
design_temp = 70.0
kappa_mag = 1.31e-5 / (2 / jnp.pi)
b_initial = jnp.array([1.0, 0.0, 0.0], dtype=jnp.complex64)

# 1. Create the specialized evaluation function
evaluate_powers = create_thg_power_evaluator(num_domains, design_wl, design_temp, kappa_mag, b_initial)

# 2. Define the range of SHG domain counts to test
shg_domain_counts = jnp.arange(1, num_domains)

# 3. Run the highly optimized calculation
thg_powers = evaluate_powers(shg_domain_counts)

# 4. Find the optimal result
best_idx = jnp.argmax(thg_powers)
optimal_num_shg = shg_domain_counts[best_idx]
optimal_ratio = optimal_num_shg / num_domains
max_power = thg_powers[best_idx]

print(f"Optimal Ratio: {optimal_ratio:.4f}")
print(f"Optimal SHG Domains: {optimal_num_shg}")
print(f"Max Power: {max_power:.6f}")

Optimal Ratio: 0.5000
Optimal SHG Domains: 500
Max Power: 0.000000


In [12]:
import jax
import jax.numpy as jnp
import plotly.graph_objects as go  # pyright: ignore[reportMissingTypeStubs]


def show_power_spectrum(wls: jax.Array, powers: jax.Array) -> None:
    fig = go.Figure()
    fig.add_trace(go.Scatter(x=wls, y=powers, mode="lines", name="Power"))
    fig.update_layout(
        title_text="Power Spectrum",
        xaxis_title="Wavelength (μm)",
        yaxis_title="Power",
    )
    fig.show()


batch_simulate_twm = jit(vmap(cwes.simulate_twm, in_axes=(None, None, 0, 0, None)))

design_wl = 1.031
design_temp = 70.0
shg_sfg_ratio = optimal_ratio
num_domains_shg = int(jnp.floor(num_domains * shg_sfg_ratio))

delta_k1 = mgoslt.calc_twm_delta_k(design_wl, design_wl, design_temp)
delta_k2 = mgoslt.calc_twm_delta_k(design_wl, design_wl / 2, design_temp)

shg_domain_width = jnp.pi / delta_k1
widths_shg = jnp.array([shg_domain_width] * num_domains)

kappa_mag = 1.31e-5 / (2 / jnp.pi)
kappa_vals = kappa_mag * jnp.power(-1, jnp.arange(num_domains))

wls = jnp.linspace(1.025, 1.035, 1000)
delta_k1s = mgoslt.calc_twm_delta_k(wls, wls, 70.0)
delta_k2s = mgoslt.calc_twm_delta_k(wls, wls / 2, 70.0)

b_initial = jnp.array([1.0, 0.0, 0.0], dtype=jnp.complex64)
b_end_shg = batch_simulate_twm(widths_shg, kappa_vals, delta_k1s, delta_k2s, b_initial)

shw_powers = jnp.abs(b_end_shg[:, 1]) ** 2 * 100 * 1.07 / 2.84

show_power_spectrum(wls, shw_powers)

sfg_domain_width = jnp.pi / delta_k2
num_domains_sfg = num_domains - num_domains_shg
print(num_domains_shg, num_domains_sfg)
widths_shg = jnp.array([shg_domain_width] * num_domains_shg)
widths_sfg = jnp.array([sfg_domain_width] * num_domains_sfg)
widths_thg = jnp.concatenate([widths_shg, widths_sfg])
kappa_vals = kappa_mag * jnp.power(-1, jnp.arange(widths_thg.shape[0]))
b_final = batch_simulate_twm(widths_thg, kappa_vals, delta_k1s, delta_k2s, b_initial)

thw_powers = jnp.abs(b_final[:, 2]) ** 2 * 100 * 1.07 / 2.84

show_power_spectrum(wls, thw_powers)

device_length = jnp.sum(widths_thg)
print(f"Device length: {device_length}")
print(f"Max power / (Device length)^2: {jnp.max(thw_powers) / device_length**2}")

500 500


Device length: 2290.19482421875
Max power / (Device length)^2: 1.4788141414548428e-12


In [None]:
widths = jnp.load("optimized_widths_1000_3.npy")
# TODO: widthsに値が負のドメインが一つだけ含まれていたのでドメインの幅を修正する。
widths = jnp.where(widths < 0, 0.98, widths)
num_domains = widths.shape[0]
kappa_vals = kappa_mag * jnp.power(-1, jnp.arange(num_domains))
wls = jnp.linspace(1.025, 1.035, 1000)
delta_k1 = mgoslt.calc_twm_delta_k(wls, wls, 70.0)
delta_k2 = mgoslt.calc_twm_delta_k(wls, wls / 2, 70.0)

b_end_shg = batch_simulate_twm(widths, kappa_vals, delta_k1s, delta_k2s, b_initial)

thw_powers = jnp.abs(b_end_shg[:, 2]) ** 2 * 100 * 1.07 / 2.84

show_power_spectrum(wls, thw_powers)

device_length = jnp.sum(widths)
print(f"Device length: {device_length}")
print(f"Max power / (Device length)^2: {jnp.max(thw_powers) / device_length**2}")

Device length: 2292.257568359375
Max power / (Device length)^2: 2.6902833259734305e-12


In [20]:
def plot_domain_widths(optimized_widths: jax.Array, periodic_widths: jax.Array) -> None:
    """Plots the domain widths before and after optimization."""
    fig = go.Figure()
    fig.add_trace(go.Scatter(y=periodic_widths, mode="lines", name="Periodic", line={"dash": "dot"}))
    fig.add_trace(go.Scatter(y=optimized_widths, mode="lines", name="Optimized"))
    fig.update_layout(
        title_text="Domain Widths Comparison",
        xaxis_title="Domain Index",
        yaxis_title="Width (μm)",
        template="plotly_white",
    )
    fig.show()


plot_domain_widths(widths, widths_thg)