In [1]:
# Environment variable
%env XLA_PYTHON_CLIENT_ALLOCATOR=platform

env: XLA_PYTHON_CLIENT_ALLOCATOR=platform


In [2]:
# ==============================================================================
# 1. Standard Python Libraries
# ==============================================================================
import os
import time
import warnings
from functools import partial

# ==============================================================================
# 2. Third-Party Scientific & ML Libraries
# ==============================================================================
import equinox as eqx
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax
from jax import config, lax, random, vmap
from jax.flatten_util import ravel_pytree
from jax.scipy.optimize import minimize

# ==============================================================================
# 3. pinn_toolkit
# ==============================================================================
from pinn_toolkit.derivative import Derivative
from pinn_toolkit.sampler import Sampler
from pinn_toolkit.train import Train
from pinn_toolkit.train_util import compute_err, generate_param
from pinn_toolkit.util import (
    L2, tree_to_f32, get_i, get_len, hex_to_key, key_to_hex,
    load_model, save_model, load_pytree, save_pytree,
    map_span, map_span_dict, split_pytree, stratified_subset,
    load_h5,
)

# ==============================================================================
# 3. local files
# ==============================================================================
from interactive_pde_suite import InteractivePDESuite
from pde_dimless import PDE_dimless
from residual import Residual

In [3]:
# set up PDE
jax.config.update("jax_enable_x64", True)
pdeparams_phys = {
    "alpha_phi": 9.62e-5, "omega_phi": 1.663e7, "M": 8.5e-10 / (2 * 5.35e7),
    "A": 5.35e7, "L": 1e-11, "c_se": 1.0, "c_le": 5100/1.43e5, "x_range": (-50.0e-6, 50.0e-6),
    "t_range": (0, 1.0e5), "nx": 32, "nt": 32, "l_0": 2*50.0e-6, "t_0": 1.0e5
}

pdedimless = PDE_dimless(pdeparams_phys)
span_pde = {
    'x':pdedimless.x_range_nd,
    't':pdedimless.t_range_nd,
    'L':(1e-12, 1e-10),
    'M':(1e-21, 1e-19)
}
span_model = {
    'x':(-0.5,0.5),
    't':(0,1),
    'L':(0,1),
    'M':(0,1)
}

# load validation data
validation_data_path = os.path.join('data', '2rqmc_5k_64*64') 
data_pde = load_h5(validation_data_path)
ref_data = tree_to_f32(map_span_dict(data_pde, span_pde, span_model))

# construct modelN
jax.config.update("jax_enable_x64", False)
#### set up the model
from model import PINN
inp_idx = {'x':0, 't':1, 'L':2, 'M':3}
out_idx =  {'phi':0, 'c':1}
# trunk_keys = ['x','t']
# latent_size = 32
# branch_size = [64,3]
# trunk_size = [64,3]
# transition_size = [32,2]
base_model = PINN(inp_idx, out_idx, span_pde, span_model, width = 32, depth = 4)

In [4]:
def train_model_single(key, pdeparams_phys, model, method, P_model, validation_chunk_size = 10, train_with_data = True):
    # ==============================================================================
    # --- 1. Preliminaries
    # ==============================================================================
    jax.config.update("jax_enable_x64", True)
    inp_idx = model.inp_idx
    out_idx = model.out_idx
    span_pde = model.span_pde
    span_model = model.span_model
    P_phys = map_span_dict(P_model, span_model, span_pde)

    # --- generate training data
    pdedimless = PDE_dimless(pdeparams_phys)
    num_gt_to_use = pdeparams_phys['nx'] * pdeparams_phys['nt'] if train_with_data else 0
    key, subkey = random.split(key)
    _, train_data = pdedimless.generate_training_data(subkey, P_phys, num_gt_to_use)
    train_data = map_span_dict(train_data, span_pde, span_model)
    
    # --- convert everything to fp32
    jax.config.update("jax_enable_x64", False)
    P_model, P_phys, train_data, span_pde, span_model = map(
        tree_to_f32,
        (P_model, P_phys, train_data, span_pde, span_model)
    )
    
    # --- config derivative
    d = Derivative(inp_idx, out_idx, span_pde, span_model)
    for deriv_name in ['phi_t', 'phi_x', 'phi_2x', 'c_t', 'c_2x']:
        d.create_deriv_fn(deriv_name)

    # --- construct residual
    r = Residual(span_pde, span_model, pdedimless, d)
    
    # --- config sampling function
    def input_single(key, P_i, train_data_i, num_data):
        keys = random.split(key, 5)
        inp = {k: {} for k in ['ic', 'bc', 'colloc', 'data']}
        # ------ ic
        inp['ic']['x'] = Sampler.get(keys[0], [16, 32, 16], [(-0.5, -0.1), (-0.1, 0.1), (0.1, 0.5)])
        inp['ic']['t'] = Sampler.get(keys[0], [64], [(0, 0)])
        # ------ bc
        inp['bc']['x'] = Sampler.get(keys[1], [32, 32], [(-0.5, -0.5), (0.5, 0.5)])
        inp['bc']['t'] = jnp.tile(Sampler.get(keys[1], [32], [(0, 1)]), 2)
        # ------ colloc
        x_colloc = Sampler.get(keys[2], [32], [(-0.5, 0.5)])
        t_colloc = Sampler.get(keys[3], [32], [(0, 1)])
        inp['colloc']['x'], inp['colloc']['t'] = map(lambda array: array.ravel(), jnp.meshgrid(x_colloc, t_colloc, indexing = "ij"))
        # ------ data
        inp['data'].update({key: train_data_i[key] for key in ['x','t','phi','c']})
            
        # Add parameters to inp
        for inp_key in inp:
            ref_array = inp[inp_key]['x']
            for param_key, param_value in P_i.items():
                inp[inp_key][param_key] = jnp.full(ref_array.shape, param_value, ref_array.dtype)
                
        return inp
    
    def new_input(key, P, train_data, num_data):
        num_params = get_len(P)
        keys = random.split(key, num_params)
        out = jax.vmap(input_single, in_axes=(0, 0, 0, None))(keys, P, train_data, num_data)
        return jax.tree_util.tree_map(lambda x: jnp.reshape(x, (-1,)), out)

    def _subset(key, pytree, subsample_size):
        return stratified_subset(key, pytree, get_len(pytree), get_len(P_model), subsample_size)
    
    def subset_input(key, inp, subset_size):
        num_subsets = len(subset_size)
        subkeys = random.split(key, num_subsets)
        return {k: _subset(subkey, inp[k], size) for (k, size), subkey in zip(subset_size.items(), subkeys)}

    # ==============================================================================
    # --- Create wrapper functions for training
    # ==============================================================================
    def update_input(key, P, train_data, num_data):
        return new_input(key, P, train_data, num_data)

    def update_weight(key, params, static, inp):
        # ------ compute NTK weight
        model_temp = eqx.combine(params, static)
        ntk_size = {'ic': 4, 'bc': 4, 'colloc': 8, 'data': 8}
        subset_to_use = subset_input(key, inp, ntk_size)
        ntk_weight = r.compute_ntk_weights(model_temp, subset_to_use)
        
        # ------ apply manual weight
        manual_weight = {'ic': 1.0, 'bc': 1.0, 'ac': 1.0, 'ch': 1.0, 'data': 1.0}
        weighted = {k: ntk_weight[k] * manual_weight[k] for k in ntk_weight}
        
        # ------ noramlize final weight
        weights_array = jnp.array(list(weighted.values()))
        geom_mean = jnp.exp(jnp.mean(jnp.log(weights_array + 1e-12)))  # add epsilon to avoid log(0)
        normalized = {k: v / geom_mean for k, v in weighted.items()}
    
        return normalized

    @eqx.filter_jit
    def loss_fn(params, static, inp, weight_dict):
        model_temp = eqx.combine(params, static)
        loss_dict = r.compute_loss(model_temp, inp)
        weighted_loss = jnp.sum(jnp.array([weight_dict[k] * loss_dict[k] for k in ['ic', 'bc', 'ac', 'ch', 'data']]))
        return weighted_loss, loss_dict

    @eqx.filter_jit
    def validation_fn(params, static):
        return compute_err(params, static, ref_data, validation_chunk_size)
    
    # ==============================================================================
    # --- Initialize Training Configurations
    # ==============================================================================
    key, subkey = random.split(key)
    total_steps = 60000
    sp1 = 200  # Resampling frequency
    sp2 = 500   # NTK weighting frequency
    sp3 = 30000 # log frequency (set very high to prevent logging)
    sp4 = 100    # Validation frequency
    log_keys_order = ('ic', 'bc', 'ac', 'ch', 'data', 'L2_phi', 'L2_c') # Use a tuple
    
    # --- generate new input
    key, subkey = random.split(key)
    inp = new_input(subkey, P_model, train_data, num_gt_to_use)
    
    # --- model param and static
    model_params, model_static = eqx.partition(model, eqx.is_inexact_array)
    
    # --- optimizer
    initial_lr = 1e-3
    # Use a learning rate schedule (e.g., exponential decay)
    lr_schedule = optax.exponential_decay(
        init_value=initial_lr,
        transition_steps=total_steps,
        decay_rate=0.5 # This means LR will be (decay_rate*100)% of initial lr at the end
    )
    
    max_grad_norm = 1.0
    weight_decay_adamw = 1e-2
    
    optimizer_adamw = optax.chain(
        optax.clip_by_global_norm(max_grad_norm),
        optax.adamw(
            learning_rate=lr_schedule, # <<< USE THE SCHEDULE HERE
            b1=0.9,
            b2=0.999,
            weight_decay=weight_decay_adamw
        )
    )
    opt_state = optimizer_adamw.init(model_params)
    
    # --- initialize carry
    key, subkey = random.split(key)
    weight_dict = update_weight(key, model_params, model_static, inp)
    init_carry = (subkey, inp, weight_dict, model_params, opt_state, jnp.inf, model_params)

    # ==============================================================================
    # --- Train the model
    # ==============================================================================
    final_carry, loss_history = Train.train(
        total_steps=total_steps,
        sp1=sp1,
        sp2=sp2,
        sp3=sp3,
        sp4=sp4,
        num_data=num_gt_to_use,
        static=model_static,
        optimizer=optimizer_adamw,
        update_input=update_input,
        update_weight=update_weight,
        loss_fn=loss_fn,
        validation_fn = validation_fn,
        log_keys_order=log_keys_order,
        P_model=P_model,
        train_data=train_data,
        carry=init_carry
    )
    
    # --- reconstruction
    _, _, _, final_param, _, best_geom, param_geom = final_carry
    model_geom = eqx.combine(param_geom, model_static)
    print(f"the best loss is {best_geom}")
    return model_geom, loss_history

In [5]:
def train_model_batch(key, pdeparams_phys, base_model, method, size, note="", repetition=1,
                      param_key_override = None, out_dir = "models", validation_chunk_size = 10):
    """
    Trains models in batches, saving each model and its loss history into a dedicated subdirectory.
    """
    os.makedirs(out_dir, exist_ok=True)
    
    for i in range(repetition):
        print(f"================= Batch repetition {i+1} / {repetition} =================")
        key, subkey = random.split(key)
        param_key, train_key = random.split(subkey,2)
        prama_key = param_key_override if param_key_override else param_key # override the param_key if necessary
        param_key_hex, train_key_hex = key_to_hex(param_key), key_to_hex(train_key)
        
        # 1. create the directory and file names
        sub_dir = f"{note}_{method}_{size}" if note else f"{method}_{size}"
        subsub_dir = f"{param_key_hex}_{train_key_hex}"
        final_dir = os.path.join(out_dir, sub_dir, subsub_dir)
        model_filename = f"model_{sub_dir}_{subsub_dir}.pkl"
        loss_filename = f"loss_{sub_dir}_{subsub_dir}.npz"
        model_path = os.path.join(final_dir, model_filename)
        loss_path = os.path.join(final_dir, loss_filename)
        
        # 2. if model exists
        if os.path.exists(model_path):
            print(f"Model {note}_{method}_{size}_{subsub_dir} exists, skipping to next model.")
            continue

        # 3. if model does not exist
        # --- create the directory to store model
        os.makedirs(final_dir, exist_ok=True)
        log_info_with_note = f"Repetition {i}: Start training | method={note}_{method} | size={size} | key={subsub_dir} | "
        log_info_without_note = f"Repetition {i}: Start training | method={method} | size={size} | key={subsub_dir} | "
        print(log_info_with_note) if note else print(log_info_without_note)

        # track the time consumed
        ts = time.time()
        # --- begin training
        
        P_model = generate_param(param_key, method, size, span_model)
        trained_model, loss_history = train_model_single(train_key, pdeparams_phys, base_model, method, P_model, validation_chunk_size)
        # --- training finished, lkog and save time.
        print(f"Total time elapsed : {(time.time()-ts)/60:.2f} minutes")
        save_model(trained_model, model_path)
        save_pytree(loss_history, loss_path)
        print(f"Saved model and loss to: {final_dir}")


In [6]:
# test if compute_err is working properly
chunk_size = 10
params, static = eqx.partition(base_model, eqx.is_inexact_array)
ts = time.time()
compute_err(params, static, ref_data, chunk_size)
print(time.time()-ts)

0.5834283828735352


In [7]:
key = jax.random.PRNGKey(18696422)
out_dir = "models"
train_model_batch(key, pdeparams_phys, base_model, 'sobol', 10, note = "test_run", repetition = 3, out_dir = out_dir)

Model test_run_sobol_10_3eb9a7683ee2ea47_84a2fb5a191fbc74 exists, skipping to next model.
Model test_run_sobol_10_ab9434a46dd342d5_a20a34a02ec9faa6 exists, skipping to next model.
Model test_run_sobol_10_3f4628a3192bab30_163c93eb2e9d272a exists, skipping to next model.


In [8]:
key = jax.random.PRNGKey(44243546190052)
out_dir = "models"
train_model_batch(key, pdeparams_phys, base_model, 'uniformRandom', 10, note = "test_run", repetition = 3, out_dir = out_dir)

Model test_run_uniformRandom_10_79eb06323ba8667a_47c6367f83f30347 exists, skipping to next model.
Model test_run_uniformRandom_10_27142e2de57df7a8_c0f8aac4dc8fd0b5 exists, skipping to next model.
Repetition 2: Start training | method=test_run_uniformRandom | size=10 | key=990807eb8dfd677c_98d5639024d22581 | 
Step 0      total_loss = 14.9755     epoch_elapsed = 20.697s  total_elapsed = 20.697s
Keys   :      ic      |      bc      |      ac      |      ch      |     data     |    L2_phi    |     L2_c    
Weights: 1.3194e+01   | 1.2220e+01   | 2.1585e-04   | 2.3228e+00   | 1.2370e+01   | 1.0000e+00   | 1.0000e+00  
Losses : 4.0205e-01   | 4.3537e-01   | 1.8686e+02   | 2.6730e-02   | 2.6472e-01   | 5.0043e-01   | 4.7301e-01  
W * L  : 5.3047e+00   | 5.3204e+00   | 4.0333e-02   | 6.2088e-02   | 3.2746e+00   | 5.0043e-01   | 4.7301e-01  
Step 30000  total_loss = 0.2017      epoch_elapsed = 464.727s  total_elapsed = 485.424s
Keys   :      ic      |      bc      |      ac      |      ch      |

In [None]:
key = jax.random.PRNGKey(612565226762)
out_dir = "models"
train_model_batch(key, pdeparams_phys, base_model, 'grid', 10, note = "test_run", repetition = 3, out_dir = out_dir)

Repetition 0: Start training | method=test_run_grid | size=10 | key=d3e068992676a75b_17045d52f73300ee | 
Step 0      total_loss = 15.1060     epoch_elapsed = 15.422s  total_elapsed = 15.422s
Keys   :      ic      |      bc      |      ac      |      ch      |     data     |    L2_phi    |     L2_c    
Weights: 1.3134e+01   | 1.2171e+01   | 2.4300e-04   | 2.0899e+00   | 1.2318e+01   | 1.0000e+00   | 1.0000e+00  
Losses : 4.0317e-01   | 4.3639e-01   | 1.8493e+02   | 2.7504e-02   | 2.7793e-01   | 5.0043e-01   | 4.7299e-01  
W * L  : 5.2952e+00   | 5.3115e+00   | 4.4938e-02   | 5.7481e-02   | 3.4234e+00   | 5.0043e-01   | 4.7299e-01  
Step 30000  total_loss = 0.1832      epoch_elapsed = 476.516s  total_elapsed = 491.938s
Keys   :      ic      |      bc      |      ac      |      ch      |     data     |    L2_phi    |     L2_c    
Weights: 9.5556e+01   | 9.8884e+02   | 2.0712e-02   | 2.2651e-06   | 2.2559e+02   | 1.0000e+00   | 1.0000e+00  
Losses : 1.2961e-04   | 2.4276e-06   | 3.5403e-01

In [None]:
key = jax.random.PRNGKey(3463523192)
out_dir = "models"
train_model_batch(key, pdeparams_phys, base_model, 'gridInner', 10, note = "test_run", repetition = 3, out_dir = out_dir)

In [None]:
# model_path = 'models/test_run_sobol_5/366b71021d1e5b49_f5653be6ce868810/model_test_run_sobol_5_366b71021d1e5b49_f5653be6ce868810.pkl'
# model_test = load_model(base_model, model_path)
# model_test_params, model_test_static = eqx.partition(model_test, eqx.is_inexact_array)
# compute_err(model_test_params, model_test_static, ref_data, 5)

# jax.config.update("jax_enable_x64", True)
# pdeparams_phys = {
#     "alpha_phi": 9.62e-5, "omega_phi": 1.663e7, "M": 8.5e-10 / (2 * 5.35e7),
#     "A": 5.35e7, "L": 1e-11, "c_se": 1.0, "c_le": 5100/1.43e5, "x_range": (-50.0e-6, 50.0e-6),
#     "t_range": (0, 1.0e5), "nx": 128, "nt": 64, "l_0": 2*50.0e-6, "t_0": 1.0e5
# }

# pdedimless = PDE_dimless(pdeparams_phys)
# num_gt_to_use = pdeparams_phys['nx'] * pdeparams_phys['nt']
# span_pde = {'x':pdedimless.x_range_nd,'t':pdedimless.t_range_nd,'L':(1e-12, 1e-10),'M':(1e-21, 1e-19)}
# span_model = {'x':(-0.5,0.5), 't':(0,1),'L':(0,1),'M':(0,1)}
    
# suite = InteractivePDESuite(pdeparams_phys)
# suite.create_interactive_comparison_plot(
#     model= model_test,
#     span_pde=span_pde,
#     span_model=span_model,
#     num_frames=10,
#     prediction_color = 'red'
# )

In [None]:
# from get_NTK import *
# model = model_test
# train_with_data = True
# P_model = generate_param(hex_to_key('6f1dce3fa814fc8b'),'sobol', 5, span_model)

In [None]:
# jax.config.update("jax_enable_x64", True)
# inp_idx = model.inp_idx
# out_idx = model.out_idx
# span_pde = model.span_pde
# span_model = model.span_model
# P_phys = map_span_dict(P_model, span_model, span_pde)

# # --- generate training data
# pdedimless = PDE_dimless(pdeparams_phys)
# num_gt_to_use = pdeparams_phys['nx'] * pdeparams_phys['nt'] if train_with_data else 0
# key, subkey = random.split(key)
# _, train_data = pdedimless.generate_training_data(subkey, P_phys, num_gt_to_use)
# train_data = map_span_dict(train_data, span_pde, span_model)

# # --- convert everything to fp32
# jax.config.update("jax_enable_x64", False)
# P_model, P_phys, train_data, span_pde, span_model = map(
#     tree_to_f32,
#     (P_model, P_phys, train_data, span_pde, span_model)
# )

# # --- config derivative
# d = Derivative(inp_idx, out_idx, span_pde, span_model)
# for deriv_name in ['phi_t', 'phi_x', 'phi_2x', 'c_t', 'c_2x']:
#     d.create_deriv_fn(deriv_name)

# # --- construct residual
# r = Residual(span_pde, span_model, pdedimless, d)

# # --- config sampling function
# def input_single(key, P_i, train_data_i, num_data):
#     keys = random.split(key, 5)
#     inp = {k: {} for k in ['ic', 'bc', 'colloc', 'data']}
#     # ------ ic
#     inp['ic']['x'] = Sampler.get(keys[0], [16, 32, 16], [(-0.5, -0.1), (-0.1, 0.1), (0.1, 0.5)])
#     inp['ic']['t'] = Sampler.get(keys[0], [64], [(0, 0)])
#     # ------ bc
#     inp['bc']['x'] = Sampler.get(keys[1], [32, 32], [(-0.5, -0.5), (0.5, 0.5)])
#     inp['bc']['t'] = jnp.tile(Sampler.get(keys[1], [32], [(0, 1)]), 2)
#     # ------ colloc
#     x_colloc = Sampler.get(keys[2], [16], [(-0.5, 0.5)])
#     t_colloc = Sampler.get(keys[3], [16], [(0, 1)])
#     inp['colloc']['x'], inp['colloc']['t'] = map(lambda array: array.ravel(), jnp.meshgrid(x_colloc, t_colloc, indexing = "ij"))
#     # ------ data
#     inp['data'].update({key: train_data_i[key] for key in ['x','t','phi','c']})
        
#     # Add parameters to inp
#     for inp_key in inp:
#         ref_array = inp[inp_key]['x']
#         for param_key, param_value in P_i.items():
#             inp[inp_key][param_key] = jnp.full(ref_array.shape, param_value, ref_array.dtype)
            
#     return inp

# def new_input(key, P, train_data, num_data):
#     num_params = get_len(P)
#     keys = random.split(key, num_params)
#     out = jax.vmap(input_single, in_axes=(0, 0, 0, None))(keys, P, train_data, num_data)
#     return jax.tree_util.tree_map(lambda x: jnp.reshape(x, (-1,)), out)

# def _subset(key, pytree, subsample_size):
#     return stratified_subset(key, pytree, get_len(pytree), get_len(P_model), subsample_size)

# # --- Helper function for sorting (place this before subset_input) ---
# def sort_points_dict(points_dict):
#     """
#     Sorts a dictionary of arrays based on a predefined key order: L, M, t, x.
#     """
#     # Define the desired sort order. This is the order of precedence.
#     sort_order = ['L', 'M', 't', 'x']

#     # Check if all necessary keys for sorting are present.
#     if not all(key in points_dict for key in sort_order):
#         # This is expected for the 'data' component if it doesn't have L, M.
#         # We return it unsorted as there's no basis for sorting.
#         return points_dict

#     # jnp.lexsort sorts by the LAST key in the tuple first (primary key).
#     # So, we provide the keys in reverse order of precedence.
#     lexsort_keys = tuple(points_dict[key] for key in reversed(sort_order))
#     sort_indices = jnp.lexsort(lexsort_keys)

#     # Apply these same indices to every array in the dictionary.
#     return tree_util.tree_map(lambda leaf: leaf[sort_indices], points_dict)


# # --- Your NEW, INTEGRATED subset_input function ---
# def subset_input(key, inp, subset_size, P_model):
#     """
#     Performs stratified subsampling and then re-sorts the results to ensure
#     a canonical order based on (L, M, t, x) for correct NTK plotting.

#     Args:
#         key (jax.random.PRNGKey): JAX random key.
#         inp (dict): The full input dictionary from new_input.
#         subset_size (dict): A dictionary specifying the number of points to sample
#                             for each component (e.g., {'ic': 16, 'colloc': 64}).
#         P_model (pytree): The pytree of model parameters (L, M), used to determine
#                           the number of strata for sampling.

#     Returns:
#         dict: The subsampled and correctly sorted input dictionary.
#     """
#     # Your existing helper function for stratified subsetting
#     def _subset(k, pytree, n_samples):
#         # Note: get_len is assumed to be defined in your environment
#         num_strata = get_len(P_model)
#         total_size = get_len(pytree)
#         return stratified_subset(k, pytree, total_size, num_strata, n_samples)

#     num_subsets = len(subset_size)
#     subkeys = random.split(key, num_subsets)

#     # 1. Perform the stratified subsampling (this shuffles the data)
#     subsampled_inp = {k: _subset(subkey, inp[k], size)
#                       for (k, size), subkey in zip(subset_size.items(), subkeys)}

#     # 2. Re-sort each component to restore the canonical parameter order
#     print("Re-sorting subsampled inputs to ensure correct parameter ordering for plotting...")
#     sorted_inp = {k: sort_points_dict(v) for k, v in subsampled_inp.items()}

#     return sorted_inp

In [None]:
# key = jax.random.PRNGKey(1314159261375)
# key, subkey = random.split(key)
# inp = new_input(key, P_model, train_data, num_gt_to_use)
# ntk_size = {'ic': 32, 'bc': 32, 'colloc': 128, 'data': 16*4}
#inp = subset_input(subkey, inp, ntk_size, P_model)

In [None]:
# NTK = get_NTK(
#     model_test, inp, r, P_model,
#     batch_size = 128, use_symlog = False,
#     plot_subgrids = "plot", out_dir = "plots", plot_name = "NTK_data.png",
#     dpi = 1200, selected_blocks = ["data"], ram_limit_gb = 8.0
# )