# In [None]:  (notebook cell marker left over from the export)
import jax
import jax.numpy as jnp
import seaborn as sns
import matplotlib.pyplot as plt
from jax import grad, vmap, jit
import optax
from flax import linen as nn
from flax.traverse_util import flatten_dict
from jax.tree_util import tree_flatten

import numpy as np
import pickle

import math
from functools import partial
from typing import Union
from dataclasses import dataclass
from tqdm import trange
import sys
import os

# -------------------------------------------------------------------------
# 0) Visualization Helpers
# -------------------------------------------------------------------------
def _find_subplot_dims_approx_4_3(n: int):
    """
    Find (rows, cols) such that:
      1) rows * cols >= n
      2) ratio = cols / rows is as close as possible to 4/3
    """
    best_rc = (1, n)
    best_diff = float('inf')
    for r in range(1, n+1):
        c = math.ceil(n / r)
        if r * c >= n:
            ratio = c / r
            diff = abs(ratio - (4/3))
            if diff < best_diff:
                best_diff = diff
                best_rc = (r, c)
    return best_rc

def plot_dists(val_dict, color="C0", xlabel=None, stat="count", use_kde=True, log_scale=False):
    """
    Plot one histogram per entry of val_dict on a grid whose aspect ratio
    is as close to 4:3 as possible, printing mean/std of each entry.

    :param val_dict: dict {name -> array-like}; each value is flattened
    :param color: plot color
    :param xlabel: x-axis label applied to every panel (skipped if falsy)
    :param stat: 'count' or 'density'
    :param use_kde: whether to overlay a kernel density estimate
                    (skipped for panels whose values are near-constant)
    :param log_scale: if True, the histogram (and the printed statistics)
                      use ``log1p`` of the values, drawn on a linear axis
    :return: the matplotlib Figure, or None when val_dict is empty
    """
    n_plots = len(val_dict)
    if n_plots == 0:
        print("No data to plot.")
        return None

    # Figure out a suitable grid arrangement
    rows, cols = _find_subplot_dims_approx_4_3(n_plots)

    # squeeze=False keeps `axes` 2-D even for a single row/column
    fig, axes = plt.subplots(rows, cols,
                             figsize=(3.0 * cols, 2.5 * rows),
                             squeeze=False)
    axes = axes.ravel()

    for ax, (key, arr) in zip(axes, val_dict.items()):
        vals = np.asarray(arr).ravel()  # flatten
        if log_scale:
            # The data are transformed here, so the axis itself stays linear:
            # also asking seaborn for a log axis would take the logarithm
            # twice and put log1p(0) == 0 on a log axis.
            vals = np.log1p(vals)
        # KDE is undefined for (near-)constant data, so skip it there.
        draw_kde = use_kde and np.std(vals) > 1e-9
        sns.histplot(vals, ax=ax, color=color, bins=50, stat=stat, kde=draw_kde)
        ax.set_title(f"{key}")
        print(f"{key}: Mean={np.mean(vals):.3e}, Std={np.std(vals):.3e}")
        if xlabel:
            ax.set_xlabel(xlabel)

    # If there are unused subplots (rows*cols > n_plots), hide them
    for spare_ax in axes[n_plots:]:
        spare_ax.set_visible(False)

    fig.tight_layout()
    return fig

def plot_wave_pde_points(
    res: np.ndarray, 
    b_left: np.ndarray, 
    b_right: np.ndarray, 
    b_lower: np.ndarray, 
    b_upper: np.ndarray
):
    """
    Scatter-plot the sampling points of the wave problem on the unit square.

    Every argument is indexed as ``pts[:, 0]`` (drawn on the x axis) and
    ``pts[:, 1]`` (drawn on the t axis).

    Parameters
    ----------
    res : ndarray
        Interior (collocation) points, drawn small and semi-transparent.
    b_left : ndarray
        Points labelled as the x=0 boundary.
    b_right : ndarray
        Points labelled as the x=1 boundary.
    b_lower : ndarray
        Points labelled as the initial condition (t=0).
    b_upper : ndarray
        Points labelled as the t=1 boundary.

    Returns
    -------
    (fig, ax) : the matplotlib figure and axes (the figure is also shown).
    """
    fig, ax = plt.subplots(figsize=(6,6))

    # (points, scatter keyword arguments), drawn in this order so that the
    # boundary markers sit on top of the collocation cloud.
    layers = (
        (res,     dict(c='k', alpha=0.3, s=8, label='Collocation (res)')),
        (b_left,  dict(c='blue', s=20, label='Left boundary (x=0)')),
        (b_right, dict(c='red', s=20, label='Right boundary (x=1)')),
        (b_lower, dict(c='green', s=20, label='Initial condition (t=0)')),
        (b_upper, dict(c='magenta', s=20, label='t=1 boundary')),
    )
    for pts, style in layers:
        ax.scatter(pts[:,0], pts[:,1], **style)

    ax.set_xlabel('x')
    ax.set_ylabel('t')
    ax.set_title('Wave PDE points: collocation & boundaries')
    ax.set_xlim([0,1])
    ax.set_ylim([0,1])
    ax.legend()
    plt.tight_layout()
    plt.show()
    return fig, ax

def flatten_intermediates(intdict):
    """
    Flatten a nested container of "intermediates" returned
    by model.apply(..., capture_intermediates=True, mutable=["intermediates"]).

    Any mapping (plain dict, or another ``collections.abc.Mapping`` such as
    Flax's FrozenDict) and any list/tuple is walked recursively. Array
    leaves (JAX or NumPy) are copied to NumPy; every other leaf is ignored.

    Returns a dict of {name -> np.ndarray},
    where 'name' is the path of keys / sequence indices joined by '/'.
    """
    # Local import keeps the block self-contained.
    from collections.abc import Mapping

    all_acts = {}

    def recurse(prefix, node):
        if isinstance(node, Mapping):
            for k, v in node.items():
                # str() so that non-string keys cannot break the join below
                recurse(prefix + [str(k)], v)
        elif isinstance(node, (list, tuple)):
            for i, v in enumerate(node):
                recurse(prefix + [str(i)], v)
        elif isinstance(node, (jnp.ndarray, np.ndarray)):
            # Convert to numpy for plotting
            all_acts["/".join(prefix)] = np.array(node)
        # anything else (scalars, None, ...) is not an activation: skip it

    recurse([], intdict)
    return all_acts


def plot_losses(pinn):
    """
    Draw the four loss histories stored on ``pinn`` (``loss_log``,
    ``loss_res_log``, ``loss_ic_log``, ``loss_bc_log``) on one
    log-scaled axis and show the figure.
    """
    # (attribute on pinn, legend label), plotted in this order
    curves = (
        ("loss_log",     "Total loss"),
        ("loss_res_log", "Residual loss"),
        ("loss_ic_log",  "IC loss"),
        ("loss_bc_log",  "BC loss"),
    )

    plt.figure(figsize=(6,4))
    for attr_name, legend_label in curves:
        plt.plot(getattr(pinn, attr_name), label=legend_label)

    plt.yscale('log')
    plt.legend()
    plt.title("Losses per iteration")
    plt.xlabel("Iteration")
    plt.ylabel("Loss")
    plt.tight_layout()
    plt.show()


def plot_weight_distributions(pinn):
    """
    Plot the value distribution of every 'weight' leaf in pinn.params.

    A leaf counts as a weight when it has more than one dimension; such
    leaves are flattened, named W_0, W_1, ... in tree-flatten order and
    passed to 'plot_dists'. Nothing is drawn when no leaf qualifies.
    """
    leaves = tree_flatten(pinn.params)[0]  # list of leaves
    matrices = (leaf for leaf in leaves if leaf.ndim > 1)
    weight_vals = {
        f"W_{w_idx}": np.array(leaf).ravel()
        for w_idx, leaf in enumerate(matrices)
    }

    fig = plot_dists(weight_vals, color="C0", xlabel="Weight values")
    if fig is None:
        # plot_dists returns None (after reporting it) when there is nothing
        # to plot; calling fig.suptitle would raise AttributeError.
        return
    fig.suptitle("Weight Distribution", fontsize=14, y=1.03)
    plt.show()


def plot_gradient_distributions(pinn):
    """
    Plot the distribution of the loss gradients of every 'weight' leaf.

    Gradients come from pinn.get_gradients(pinn.params); leaves with more
    than one dimension are flattened, named G_0, G_1, ... in tree-flatten
    order and passed to 'plot_dists'. Nothing is drawn when no leaf
    qualifies.
    """
    grads = pinn.get_gradients(pinn.params)     # PyTree of gradients
    leaves = tree_flatten(grads)[0]             # list of leaves
    matrices = (leaf for leaf in leaves if leaf.ndim > 1)
    grad_vals = {
        f"G_{g_idx}": np.array(leaf).ravel()
        for g_idx, leaf in enumerate(matrices)
    }

    fig = plot_dists(grad_vals, color="C1", xlabel="Gradient values")
    if fig is None:
        # plot_dists returns None (after reporting it) when there is nothing
        # to plot; calling fig.suptitle would raise AttributeError.
        return
    fig.suptitle("Gradient Distribution", fontsize=14, y=1.03)
    plt.show()

def plot_activation_distributions(pinn, example_input):
    """
    Capture the intermediate activations of one forward pass through
    pinn.get_intermediates, flatten them, and plot with 'plot_dists'.

    Only activations with more than one dimension are plotted; they are
    titled "Act_<k>(<path>)". Nothing is drawn when none qualifies.

    :param example_input: batch fed to the network as-is; it must already
                          be in the feature layout the model expects.
    """
    # The final network output is not needed here, only the captured values.
    _, intermediates = pinn.get_intermediates(pinn.params, example_input)
    flat_acts = flatten_intermediates(intermediates)

    tensors = ((name, arr) for name, arr in flat_acts.items() if arr.ndim > 1)
    activations_dict = {
        f"Act_{idx}({name})": arr.ravel()
        for idx, (name, arr) in enumerate(tensors)
    }

    fig = plot_dists(activations_dict, color="C2", stat="density", xlabel="Activation values")
    if fig is None:
        # plot_dists returns None (after reporting it) when there is nothing
        # to plot; calling fig.suptitle would raise AttributeError.
        return
    fig.suptitle("Activation Distribution", fontsize=14, y=1.03)
    plt.show()

def plot_wave_solution(
    u_2d: np.ndarray,
    x_range: tuple = (0,1),
    t_range: tuple = (0,1),
    title: str = "Wave Solution",
    cmap: str = "viridis"
):
    """
    Show a 2D array of solution values as a colour image over (x, t).

    Parameters
    ----------
    u_2d : np.ndarray
        2D array of u values. Rows are drawn along the vertical (t) axis
        and columns along the horizontal (x) axis, so an array laid out
        as (n_t, n_x) is displayed without transposition.

    x_range : (float, float)
        Values mapped to the left and right edges of the image.

    t_range : (float, float)
        Values mapped to the bottom and top edges of the image.

    title : str
        Title for the plot.

    cmap : str
        Matplotlib colormap name for imshow.
    """
    plt.figure(figsize=(6,5))

    # imshow expects extent as [left, right, bottom, top].
    image_extent = [x_range[0], x_range[1], t_range[0], t_range[1]]

    # origin="lower" places row 0 of the array at the bottom of the image.
    plt.imshow(u_2d, extent=image_extent, aspect='auto', cmap=cmap, origin="lower")

    plt.title(title)
    plt.xlabel("x")
    plt.ylabel("t")
    plt.colorbar(label="u(x,t)")
    plt.tight_layout()
    plt.show()

# -------------------------------------------------------------------------
# 1) MODEL ARGUMENTS
# -------------------------------------------------------------------------
@dataclass
class ModelArgs:
    """
    Hyper-parameters shared by the Mamba modules in this file.

    After construction two derived values are available:
    ``d_inner`` (= int(expand * d_model)) and a resolved ``dt_rank``
    ('auto' becomes ceil(d_model / 16)).
    """
    d_model: int = 64                    # feature width of the residual stream
    n_layer: int = 2                     # number of stacked residual blocks
    d_state: int = 16                    # state size per inner channel
    expand: int = 2                      # d_inner / d_model ratio
    dt_rank: Union[int, str] = 'auto'    # width of the delta projection, or 'auto'
    d_conv: int = 4                      # kernel size of the 1D convolution
    conv_bias: bool = True               # bias flag for the convolution
    bias: bool = False                   # bias flag for the Dense projections

    def __post_init__(self):
        # Resolve the 'auto' placeholder into a concrete integer rank.
        if self.dt_rank == 'auto':
            self.dt_rank = math.ceil(self.d_model / 16)
        # Derived (non-field) attribute: width of the expanded inner stream.
        self.d_inner = int(self.expand * self.d_model)

# Initialisers passed explicitly to most Dense layers below:
# Xavier-uniform kernels and biases set to the constant 0.
dense_kernel_init = nn.initializers.xavier_uniform()
dense_bias_init = nn.initializers.constant(0.00)

# -------------------------------------------------------------------------
# 2) HELPER MODULES (RMSNorm, MambaBlock, etc.)
# -------------------------------------------------------------------------
class RMSNorm(nn.Module):
    """
    Placeholder for a normalisation layer.

    Despite its name, the forward pass returns its input unchanged: no RMS
    scaling is applied and no parameters are created.

    NOTE(review): presumably normalisation was disabled on purpose; confirm
    before replacing this with a real RMSNorm.
    """
    # Accepted by the constructor but not read by __call__.
    d_model: int
    # Accepted by the constructor but not read by __call__.
    eps: float = 1e-5

    @nn.compact
    def __call__(self, x):
        # Identity mapping: x passes through untouched.
        return x 

class MambaBlock(nn.Module):
    """
    Gated state-space block.

    Pipeline of __call__: input projection (split into a main branch and a
    gate branch) -> depthwise 1D convolution + GELU -> input-dependent
    delta, B, C -> linear state-space recurrence (``ssm``) -> GELU gate ->
    output projection back to d_model features.
    """
    # Hyper-parameters (widths, state size, bias flags).
    args: ModelArgs

    @nn.compact
    def __call__(self, x):
        # x is unpacked below as a rank-3 tensor (batch, length, features).
        args = self.args
        d_model = args.d_model
        d_inner = args.d_inner
        d_state = args.d_state
        dt_rank = args.dt_rank

        # Input Projection: one Dense produces both branches, split in half
        # along the feature axis into the main path and the gate path.
        x_and_res = nn.Dense(
            features=d_inner * 2, use_bias=args.bias,
            kernel_init=dense_kernel_init, bias_init=dense_bias_init
        )(x)
        x_proj, res = jnp.split(x_and_res, 2, axis=-1)

        # Depthwise Convolution (feature_group_count == features means one
        # filter per channel).
        # NOTE(review): padding='SAME' is not a causal padding, so a position
        # presumably also sees later sequence positions — confirm intended.
        x_conv = x_proj
        x_conv = nn.Conv(
            features=d_inner,
            kernel_size=(args.d_conv,),
            feature_group_count=d_inner,
            padding='SAME',
            use_bias=args.conv_bias
        )(x_conv)
        x = nn.gelu(x_conv)

        # Compute delta, B, C from the convolved signal with a single Dense:
        # the first dt_rank features feed delta, the rest hold B and C.
        x_dbl = nn.Dense(
            features=dt_rank + 2 * d_state * d_inner,
            use_bias=args.bias,
            kernel_init=dense_kernel_init,
            bias_init=dense_bias_init
        )(x)
        delta, B_C = jnp.split(x_dbl, [dt_rank], axis=-1)
        b, l, _ = delta.shape

        # B and C each get shape (b, l, d_inner, d_state), i.e. they are
        # per-channel as well as per-position.
        B_C = B_C.reshape(b, l, d_inner, 2*d_state)
        B, C = jnp.split(B_C, 2, axis=-1) 
        # Expand delta from dt_rank to d_inner features; softplus keeps it
        # positive. This Dense leaves use_bias at its default rather than
        # following args.bias.
        delta = nn.softplus(
            nn.Dense(d_inner, kernel_init=dense_kernel_init, bias_init=dense_bias_init)(delta)
        )

        # State Space Params
        A_log = self.param('A_log', nn.initializers.normal(), (d_inner, d_state))
        D = self.param('D', nn.initializers.ones, (d_inner,))
        # A is strictly negative, so exp(delta * A) in ssm lies in (0, 1).
        A = -jnp.exp(A_log)

        y = self.ssm(x, delta, A, B, C, D)
        # Gate the recurrence output with the second input-projection branch.
        y = y * nn.gelu(res)
        output = nn.Dense(
            features=d_model, use_bias=args.bias,
            kernel_init=dense_kernel_init, bias_init=dense_bias_init
        )(y)
        return output

    def ssm(self, x, delta, A, B, C, D):
        """
        Run the discretised linear recurrence along the sequence axis.

        With h_0 = 0, for every step t:
            h_t = exp(delta_t * A) * h_{t-1} + delta_t * B_t * x_t
            y_t = sum_n(h_t * C_t) + D * x_t

        Shapes as used by the einsum strings below: x and delta are
        (b, l, d), A is (d, n), B and C are (b, l, d, n), D is (d,).
        Returns y with shape (b, l, d).
        """
        b, l, d = x.shape
        n = A.shape[1]
        # Per-step decay factor and input contribution, both (b, l, d, n).
        deltaA = jnp.exp(jnp.einsum('bld,dn->bldn', delta, A))
        deltaB_u = jnp.einsum('bld,bldn,bld->bldn', delta, B, x)

        # Initial hidden state: zeros for every batch element.
        x_state = jnp.zeros((b, d, n))

        def scan_fn(carry, inputs):
            # One recurrence step; the new state is both carried and emitted.
            x_prev = carry
            deltaA_t, deltaB_u_t = inputs
            x_t = deltaA_t * x_prev + deltaB_u_t
            return x_t, x_t

        # lax.scan iterates over the leading axis, so move length to the front.
        deltaA_list = deltaA.transpose(1, 0, 2, 3)   # => (l,b,d,n)
        deltaB_u_list = deltaB_u.transpose(1, 0, 2, 3)
        _, x_states = jax.lax.scan(scan_fn, x_state, (deltaA_list, deltaB_u_list))
        x_states = x_states.transpose(1, 0, 2, 3)  # => (b,l,d,n)

        # Read out: contract the state axis with C, then add the skip term.
        y = jnp.einsum('bldn,bldn->bld', x_states, C)
        y = y + x * D[None, None, :]
        return y

class ResidualBlock(nn.Module):
    """Skip connection around RMSNorm followed by MambaBlock: x + block(norm(x))."""
    args: ModelArgs

    @nn.compact
    def __call__(self, x):
        cfg = self.args
        normed = RMSNorm(cfg.d_model)(x)
        # Add the block output back onto the un-normalised input.
        return x + MambaBlock(cfg)(normed)

class Mamba(nn.Module):
    """
    Full network: Dense embedding to d_model features, n_layer residual
    blocks, RMSNorm, and a Dense head producing one value per position.
    """
    args: ModelArgs

    @nn.compact
    def __call__(self, x):
        # x: (batch_size, seq_len, 2)
        cfg = self.args

        # Embedding layer (uses Flax's default initialisers).
        hidden = nn.Dense(features=cfg.d_model)(x)

        for _ in range(cfg.n_layer):
            hidden = ResidualBlock(cfg)(hidden)

        hidden = RMSNorm(cfg.d_model)(hidden)

        # Scalar head => (batch_size, seq_len, 1)
        head = nn.Dense(features=1, kernel_init=dense_kernel_init, bias_init=dense_bias_init)
        return head(hidden)

# -------------------------------------------------------------------------
# 3) MambaWavePINN (Slightly Extended for Visualization)
# -------------------------------------------------------------------------
class MambaWavePINN:
    """
    Physics-informed training wrapper around the Mamba network for the wave
    equation u_tt - 4*u_xx = 0.

    The loss is the sum of three mean-squared terms:
      * the PDE residual on the collocation sequences,
      * the initial condition u(x,0) = sin(pi x) + 0.5 sin(3 pi x) together
        with u_t(x,0) = 0,
      * the boundary condition u = 0 on the boundary sequences.

    The x/t arrays of each data set are concatenated on the last axis in
    (t, x) order to form the network input.
    """

    def __init__(
        self,
        key,
        args: ModelArgs,
        x_seq_res: jnp.ndarray,
        t_seq_res: jnp.ndarray,
        x_seq_ic: jnp.ndarray,
        t_seq_ic: jnp.ndarray,
        x_seq_bc: jnp.ndarray,
        t_seq_bc: jnp.ndarray,
        learning_rate=1e-3,
    ):
        """
        Build the model, initialise parameters and the Adam optimiser,
        store the training data and print a parameter summary.

        :param key: JAX PRNG key used for parameter initialisation
        :param args: model hyper-parameters
        :param x_seq_res, t_seq_res: collocation coordinates (PDE residual)
        :param x_seq_ic, t_seq_ic: initial-condition coordinates
        :param x_seq_bc, t_seq_bc: boundary-condition coordinates
        :param learning_rate: Adam learning rate
        """
        self.model = Mamba(args)
        self.args = args
        self.key = key
        # Initialize params with a dummy input => shape (1, 5, 2)
        dummy_input = jnp.ones((1, 5, 2))
        self.params = self.model.init(self.key, dummy_input)

        # Store PDE data
        self.x_seq_res = x_seq_res
        self.t_seq_res = t_seq_res
        self.x_seq_ic  = x_seq_ic
        self.t_seq_ic  = t_seq_ic
        self.x_seq_bc  = x_seq_bc
        self.t_seq_bc  = t_seq_bc

        self.optimizer = optax.adam(learning_rate)
        self.opt_state = self.optimizer.init(self.params)

        # Logs (one entry per logged iteration, see train)
        self.loss_log = []
        self.loss_ic_log = []
        self.loss_bc_log = []
        self.loss_res_log = []

        self.print_model_summary()

    def neural_net(self, params, t_seq, x_seq):
        """ Evaluate Mamba model on (t_seq, x_seq) => shape (batch_size, seq_len, 1). """
        tx_inputs = jnp.concatenate([t_seq, x_seq], axis=-1)  # => (B,L,2)
        return self.model.apply(params, tx_inputs)

    def _u_scalar(self, params, t, x):
        """ Evaluate the network at a single (t, x) point, fed as a length-1 sequence; returns a scalar. """
        inp = jnp.array([[[t, x]]])  # shape (1,1,2)
        return self.model.apply(params, inp)[0, 0, 0]

    def residual_net(self, params, t_seq, x_seq):
        """
        PDE residual for wave eqn: u_tt - 4u_xx = 0.

        Every (t, x) pair is differentiated independently through a
        length-1 sequence, so no neighbouring sequence positions take part
        in the derivative. The result has the shape of t_seq.
        """
        t_flat = t_seq.reshape(-1)
        x_flat = x_seq.reshape(-1)

        def u_fn(t, x):
            return self._u_scalar(params, t, x)

        # Second derivatives by nesting grad on the same argument.
        u_tt_scalar = grad(grad(u_fn, argnums=0), argnums=0)
        u_xx_scalar = grad(grad(u_fn, argnums=1), argnums=1)

        u_tt_vals = vmap(u_tt_scalar, in_axes=(0, 0))(t_flat, x_flat)
        u_xx_vals = vmap(u_xx_scalar, in_axes=(0, 0))(t_flat, x_flat)
        r_vals = u_tt_vals - 4.0*u_xx_vals
        return r_vals.reshape(t_seq.shape)

    @partial(jit, static_argnums=0)
    def loss_fn(self, params):
        """
        Total loss and its parts.

        `self` is a static jit argument, so the stored data arrays are
        captured when the function is traced.

        :return: (loss_total, (loss_res, loss_ic, loss_bc))
        """
        # 1) PDE residual
        r_vals = self.residual_net(params, self.t_seq_res, self.x_seq_res)
        loss_res = jnp.mean(r_vals**2)

        # 2) Initial condition
        u_ic_pred = self.neural_net(params, self.t_seq_ic, self.x_seq_ic)  # => (B,L,1)
        x_ic_flat = self.x_seq_ic.reshape(-1)
        # Must equal wave_exact_solution at t=0, i.e. "+ 0.5*sin(3 pi x)";
        # a minus sign here would train towards a different solution than
        # the one the predictions are compared against.
        u_ic_true = jnp.sin(jnp.pi*x_ic_flat) + 0.5*jnp.sin(3.0*jnp.pi*x_ic_flat)
        u_ic_true = u_ic_true.reshape(u_ic_pred.shape[0], u_ic_pred.shape[1])
        loss_ic_1 = jnp.mean((u_ic_pred.squeeze(-1) - u_ic_true)**2)

        # Enforce partial u/partial t = 0 at t=0
        def scalar_u(t, x):
            return self._u_scalar(params, t, x)
        u_t0 = vmap(lambda xx: grad(scalar_u, argnums=0)(0.0, xx))(x_ic_flat)
        loss_ic_2 = jnp.mean(u_t0**2)
        loss_ic = loss_ic_1 + loss_ic_2

        # 3) Boundary condition => u=0 at x=0 or x=1
        u_bc_pred = self.neural_net(params, self.t_seq_bc, self.x_seq_bc)
        loss_bc = jnp.mean(u_bc_pred**2)

        loss_total = loss_res + loss_ic + loss_bc
        return loss_total, (loss_res, loss_ic, loss_bc)

    @partial(jit, static_argnums=0)
    def update_step(self, params, opt_state):
        """
        One optimiser step.

        :return: (params, opt_state, loss_total, loss_res, loss_ic, loss_bc),
                 the losses being those evaluated before the update.
        """
        (loss_val, (loss_res, loss_ic, loss_bc)), grads = jax.value_and_grad(
            self.loss_fn, has_aux=True
        )(params)
        updates, opt_state = self.optimizer.update(grads, opt_state)
        params = optax.apply_updates(params, updates)
        return params, opt_state, loss_val, loss_res, loss_ic, loss_bc

    def train(self, n_iter=5000, print_every=100):
        """
        Run n_iter optimiser steps, then store the final parameters and
        optimiser state on the instance.

        Losses are printed and appended to the logs only on iterations
        where i % print_every == 0, so the logs hold one point per
        print_every steps.
        """
        params = self.params
        opt_state = self.opt_state
        for i in range(n_iter):
            params, opt_state, loss_val, l_res, l_ic, l_bc = self.update_step(params, opt_state)
            if i % print_every == 0:
                # Convert once; the same Python floats are logged and printed.
                loss_val, l_res, l_ic, l_bc = (float(v) for v in (loss_val, l_res, l_ic, l_bc))
                self.loss_log.append(loss_val)
                self.loss_res_log.append(l_res)
                self.loss_ic_log.append(l_ic)
                self.loss_bc_log.append(l_bc)
                print(f"[Iter {i}] Loss={loss_val:.3e} Res={l_res:.3e} IC={l_ic:.3e} BC={l_bc:.3e}")
        self.params = params
        self.opt_state = opt_state

    def predict(self, t_seq, x_seq):
        """ Evaluate the network with the currently stored parameters. """
        return self.neural_net(self.params, t_seq, x_seq)

    def print_model_summary(self):
        """ Print the path, shape and size of every parameter and the total count. """
        flat_params = flatten_dict(self.params)
        total_params = 0
        print("Model Summary (MambaWavePINN):")
        for path, param in flat_params.items():
            print(f"{'/'.join(path)}: shape={param.shape}, size={param.size}")
            total_params += param.size
        print(f"Total parameters: {total_params}")

    # ---------------------------------------------------------------------
    # (A) Capture Activations
    # ---------------------------------------------------------------------
    def get_intermediates(self, params, example_seq):
        """
        Returns (final_output, intermediates).

        `intermediates` is the "intermediates" collection itself (already
        unwrapped from the mutable-variables result), keyed by module name.
        example_seq is passed to the model unchanged, so it must already be
        in the model's input layout.
        """
        # 'mutable=["intermediates"]' + 'capture_intermediates=True'
        # let us see all intermediate activations from the forward pass.
        out, info = self.model.apply(
            params, example_seq,
            capture_intermediates=True,
            mutable=["intermediates"]
        )
        intermediates = info["intermediates"]
        return out, intermediates

    # ---------------------------------------------------------------------
    # (B) Compute Gradients w.r.t. total loss
    # ---------------------------------------------------------------------
    def get_gradients(self, params):
        """
        Return the gradient of the total loss (residual + IC + BC) with
        respect to params, as a PyTree with the structure of params.
        """
        def loss_only(p):
            return self.loss_fn(p)[0]  # total scalar loss, aux terms dropped
        grads = jax.grad(loss_only)(params)
        return grads


# -------------------------------------------------------------------------
# 4) Example Usage & Visualization
# -------------------------------------------------------------------------

def make_time_sequence(src, num_step=5, step=1e-4):
    """
    Turn every point of src into a short sequence along time.

    Each row of src is repeated num_step times and the value in column 1
    of the i-th copy is increased by step*i; column 0 is left unchanged.
    src itself is not modified.

    NOTE: column 1 is assumed to hold t, i.e. rows are (x_i, t_i).
    If your ordering is (t, x), adapt accordingly.

    :param src: array-like of shape (N, 2)
    :param num_step: sequence length
    :param step: time increment between consecutive sequence positions
    :return: np.ndarray of shape (N, num_step, 2); integer or boolean input
             is promoted to float so fractional offsets can be stored
    """
    src = np.asarray(src)
    if not np.issubdtype(src.dtype, np.inexact):
        # An in-place float add on an integer array raises a casting error.
        src = src.astype(float)

    seq = np.tile(src[:, None, :], (1, num_step, 1))
    # One offset per sequence position, broadcast over all N points.
    offsets = (step * np.arange(num_step)).astype(seq.dtype)
    seq[:, :, 1] += offsets
    return seq

def get_data(x_range, y_range, x_num, y_num):
    """
    Build a regular grid of (x, t) points and pick out its four edges.

    x takes x_num evenly spaced values over x_range and t takes y_num
    evenly spaced values over y_range. Points are listed with x varying
    fastest.

    :return: (data, b_left, b_right, b_upper, b_lower) where data is
             (x_num*y_num, 2) with rows (x, t), and the others are the rows
             with x == x_range[0], x == x_range[1], t == y_range[1] and
             t == y_range[0] respectively (compared with np.isclose).
    """
    xs = np.linspace(x_range[0], x_range[1], x_num)
    ts = np.linspace(y_range[0], y_range[1], y_num)
    grid_x, grid_t = np.meshgrid(xs, ts)

    flat_x = grid_x.ravel()
    flat_t = grid_t.ravel()
    data = np.stack([flat_x, flat_t], axis=-1)  # => (N,2), each row (x,t)

    def rows_where(coords, value):
        # Rows of `data` whose coordinate matches `value` up to tolerance.
        return data[np.isclose(coords, value)]

    return (
        data,
        rows_where(flat_x, x_range[0]),   # left
        rows_where(flat_x, x_range[1]),   # right
        rows_where(flat_t, y_range[1]),   # upper
        rows_where(flat_t, y_range[0]),   # lower
    )

def get_data_wave_1d(x_range, t_range, x_num, t_num):
    """
    Grid and edge points for the 1D wave problem; forwards to get_data with
    t as the second coordinate.

    :return: (res, b_left, b_right, b_upper, b_lower)
    """
    return get_data(x_range, t_range, x_num, t_num)

# exact solution
def wave_exact_solution(x, t):
    """
    Closed-form reference solution, evaluated element-wise:
      u(x,t) = sin(pi*x)*cos(2*pi*t) + 0.5*sin(3*pi*x)*cos(6*pi*t)

    Note the argument order: x first, then t.
    """
    first_mode = np.sin(np.pi*x)*np.cos(2.0*np.pi*t)
    third_mode = 0.5*np.sin(3.0*np.pi*x)*np.cos(6.0*np.pi*t)
    return first_mode + third_mode



# 1) Create data and visualize PDE points
#    (31 x 31 grid on the unit square; every row of these arrays is (x, t))
res, b_left, b_right, b_upper, b_lower = get_data_wave_1d([0,1],[0,1], 31,31)
plot_wave_pde_points(res, b_left, b_right, b_lower, b_upper)

# 2) Make time sequences
# PDE residual data: each grid point becomes 5 steps spaced 1e-4 apart in t
res_seq = make_time_sequence(res, 5, 1e-4)  # shape (N,5,2)
res_seq_jnp = jnp.array(res_seq)
x_seq_res = res_seq_jnp[:,:,0:1]  # (N,5,1)
t_seq_res = res_seq_jnp[:,:,1:2]  # (N,5,1)

# Initial condition (t=0 => b_lower); length-1 sequences with no time offset
b_lower_seq = make_time_sequence(b_lower, 1, 0.0)
b_lower_seq_jnp = jnp.array(b_lower_seq)
x_seq_ic = b_lower_seq_jnp[:,:,0:1]
t_seq_ic = b_lower_seq_jnp[:,:,1:2]

# Boundary condition (x=0 or 1 => b_left + b_right)
b_bc = np.concatenate([b_left, b_right], axis=0)
b_bc_seq = make_time_sequence(b_bc, 5, 1e-4)
b_bc_seq_jnp = jnp.array(b_bc_seq)
x_seq_bc = b_bc_seq_jnp[:,:,0:1]
t_seq_bc = b_bc_seq_jnp[:,:,1:2]


# 3) Initialize the model (the constructor also prints a parameter summary)
key = jax.random.PRNGKey(0)
args = ModelArgs(d_model=8, n_layer=2, d_state=8, expand=2, dt_rank='auto')
pinn = MambaWavePINN(
    key=key,
    args=args,
    x_seq_res=x_seq_res, t_seq_res=t_seq_res,
    x_seq_ic=x_seq_ic,   t_seq_ic=t_seq_ic,
    x_seq_bc=x_seq_bc,   t_seq_bc=t_seq_bc,
    learning_rate=1e-3
)

# Reference solution at the first step of every residual sequence.
# wave_exact_solution is declared as (x, t); keywords keep the two
# coordinates from being swapped.
u_exact = wave_exact_solution(
    x=x_seq_res[:,0:1,:], t=t_seq_res[:,0:1,:]
).reshape(31,31)
# The grid points are listed with x varying fastest, so after the reshape
# t runs along axis 0 and x along axis 1.
plot_wave_solution(u_exact, x_range=(0,1), t_range=(0,1), title="Exact Wave Solution")


# Prediction of the untrained network on the same points
u_pred = pinn.predict(t_seq_res[:,0:1,:], x_seq_res[:,0:1,:]).reshape(31,31)
plot_wave_solution(u_pred, x_range=(0,1), t_range=(0,1), title="Predicted Wave Solution")


# The network consumes features in (t, x) order (see MambaWavePINN.neural_net)
# while res_seq_jnp stores (x, t); reorder before capturing activations so
# they are measured on the same inputs the model is trained on.
tx_seq_res = jnp.concatenate([t_seq_res, x_seq_res], axis=-1)

# Diagnostics before training
plot_weight_distributions(pinn)
plot_gradient_distributions(pinn)
plot_activation_distributions(pinn, tx_seq_res)

# Train a bit
pinn.train(n_iter=2000, print_every=50)

# === Plot in the requested order ===

# Diagnostics after training
plot_weight_distributions(pinn)
plot_gradient_distributions(pinn)
plot_activation_distributions(pinn, tx_seq_res)

plot_losses(pinn)
u_pred = pinn.predict(t_seq_res[:,0:1,:], x_seq_res[:,0:1,:]).reshape(31,31)
plot_wave_solution(u_pred, x_range=(0,1), t_range=(0,1), title="Predicted Wave Solution")

print("Done.")