In [2]:
# Environment variable (keep this at the top)
# It must be set before JAX initialises its backend, so it precedes `import jax`.
# With "platform", XLA allocates device memory on demand and frees it when no
# longer needed instead of preallocating a large pool (slower; see the JAX
# "GPU memory allocation" docs).
%env XLA_PYTHON_CLIENT_ALLOCATOR=platform

# === 1. Imports ===

# --- Third-Party Libraries ---
import jax
import jax.numpy as jnp
import os  # NOTE(review): in the visible cells, only the commented-out model-loading code uses os
# --- Local Project Modules ---
# `model`, `pde_dimless` and `interactive_pde_suite` are imported as top-level
# modules (they must be importable from the working directory / sys.path);
# only `load_model` is imported from the installable pinn_toolkit package.
from model import PINN
from pde_dimless import PDE_dimless  # NOTE(review): not imported via pinn_toolkit — confirm where pde_dimless.py lives
from interactive_pde_suite import InteractivePDESuite
from pinn_toolkit.util import load_model  # used only by the commented-out model-loading code in the next cell

env: XLA_PYTHON_CLIENT_ALLOCATOR=platform


In [3]:
# === 2. JAX Configuration & Physical Parameters ===

# Enable float64 while the physical parameters are non-dimensionalised below.
# NOTE(review): x64 is switched off again in section 3, and the
# InteractivePDESuite (which JIT-compiles the solver) is only constructed in the
# next cell, i.e. in 32-bit mode. Confirm whether the suite re-enables x64
# itself; otherwise the solver is not actually running in float64.
jax.config.update("jax_enable_x64", True)

# Fixed PRNG seed for reproducibility.
# NOTE(review): `key` is not used in the visible cells.
key = jax.random.PRNGKey(20987167762)

# Quantities that other parameters are derived from. They are defined once so
# that l_0 / t_0 stay consistent with the domain and M stays consistent with A
# when any of them is edited.
X_RANGE_PHYS = (-50.0e-6, 50.0e-6)
T_RANGE_PHYS = (0, 1.0e5)
A_PHYS = 5.35e7

pdeparams_phys = {
    "alpha_phi": 9.62e-5, "omega_phi": 1.663e7,
    # Presumably M = D / (2A) with D = 8.5e-10 — confirm the meaning of 8.5e-10.
    "M": 8.5e-10 / (2 * A_PHYS),
    "A": A_PHYS, "L": 1e-11, "c_se": 1.0, "c_le": 5100/1.43e5,
    "x_range": X_RANGE_PHYS, "t_range": T_RANGE_PHYS,
    "nx": 128, "nt": 64,
    # Scales: the full domain width and the full simulated time span.
    "l_0": X_RANGE_PHYS[1] - X_RANGE_PHYS[0],
    "t_0": T_RANGE_PHYS[1] - T_RANGE_PHYS[0],
}

pdedimless = PDE_dimless(pdeparams_phys)

# Input spans in PDE coordinates: x and t come from the non-dimensionalised
# domain; L and M are given as explicit ranges.
# NOTE(review): the nominal M above (~7.9e-18) lies outside the M span
# (1e-21, 1e-20), while the nominal L sits at the lower edge of its span —
# confirm this is intended.
span_pde = {
    'x': pdedimless.x_range_nd,
    't': pdedimless.t_range_nd,
    'L': (1e-11, 1e-10),
    'M': (1e-21, 1e-20),
    # To vary these too, uncomment here, in span_model, AND add them to inp_idx.
    # 'alpha_phi': (1e-6, 1e-3),
    # 'omega_phi': (1e6, 1e7)
}
# The same inputs mapped to the network's coordinate ranges.
span_model = {
    'x': (-0.5, 0.5), 't': (0, 1), 'L': (0, 1), 'M': (0, 1),
    # 'alpha_phi': (0, 1), 'omega_phi': (0, 1)
}


# === 3. Model Initialization ===

# Switch to float32 for the neural network for better performance
jax.config.update("jax_enable_x64", False)

inp_idx = {'x': 0, 't': 1, 'L': 2, 'M': 3}
out_idx = {'phi': 0, 'c': 1}

# PINN receives all three dicts; they are assumed to be keyed by the same input
# names. Fail fast on a half-applied edit (e.g. a span uncommented in only one
# place) instead of inside the model.
assert set(span_pde) == set(span_model) == set(inp_idx), (
    "span_pde, span_model and inp_idx must share the same keys"
)

width = 32
depth = 4
base_model = PINN(inp_idx, out_idx, span_pde, span_model, width, depth)

# Optional: load a previously trained model from the test run. This is disabled;
# the next cell plots `base_model` (the network constructed above). To use a
# trained model, uncomment these lines and pass `trained_model` in the next cell.
# NOTE(review): the file is a .pkl and is presumably unpickled — only load trusted files.
# model_path = os.path.join('models', 'test_run_sobol_10', '3eb9a7683ee2ea47_84a2fb5a191fbc74')
# model_name = os.path.join(model_path, 'model_test_run_sobol_10_3eb9a7683ee2ea47_84a2fb5a191fbc74.pkl')
# trained_model = load_model(base_model, model_name)

In [4]:
# === 4. Interactive Analysis ===
# Build the interactive suite from the physical parameters, then render the
# comparison plot of the network's predictions against the PDE solution.
# NOTE(review): jax_enable_x64 was switched off in the previous cell, so the
# suite (and the solver it JIT-compiles) is created in 32-bit mode — confirm
# this is intended.
comparison_kwargs = dict(
    model=base_model,  # replace with a loaded/trained model when one is available
    span_pde=span_pde,
    span_model=span_model,
    num_frames=10,
    prediction_color='match',
)
suite = InteractivePDESuite(pdeparams_phys)
suite.create_interactive_comparison_plot(**comparison_kwargs)

Initializing InteractivePDESuite...
Solver has been JIT-compiled. Ready for interactive use.


VBox(children=(HBox(children=(FloatLogSlider(value=1e-11, continuous_update=False, description='L (Mobility)',…

Output(layout=Layout(min_height='850px'))