# **Applying KAN-PIN to the Marmousi Velocity Model**  

For the Marmousi velocity model, where the velocity varies sharply and on many scales, a KAN trained on trajectory data alone might overfit to noise or spurious correlations. 
By training the KAN as a Physics-Informed Neural Network (PINN), we can mitigate this: the loss also penalizes predictions that violate the ray-tracing equations, which improves generalization and stability. 
This approach allows the KAN to learn solutions that are both data-efficient and physically consistent.

## **Setup**

The following cells configure the environment for the experiment. This includes:
*   **Loading extensions and libraries**: Essential for data manipulation, deep learning, and visualization.
*   **Device configuration**: Sets the computation device to CUDA if available, otherwise defaults to CPU.
*   **Path definitions**: Specifies the paths for data, output, and model checkpoints.

In [2]:
%load_ext autoreload
%autoreload 2

import os
import random
from itertools import cycle, product
from functools import partial
from pathlib import Path
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import plotly.graph_objects as go

from tqdm.auto import tqdm
from kan import KAN

import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader, Sampler
from utils.architecture import Architecture
from rt_python import DataGeneratorMarmousi
from utils.metrics import score

SEED = 42



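Use the CUDA device when available; otherwise, fall back to the CPU. The second assignment in the cell below overrides this choice and forces the CPU for this example; remove it to train on a GPU.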

In [None]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DEVICE = torch.device("cpu")  # Force CPU for this example
print(f"Using '{DEVICE}' device.")

Set up the paths to the data, output, and checkpoint folders. 

In [None]:
DATA_PATH = Path("../data/best")
OUTPUT_PATH = Path("../output/best")
CHECKPOINT_PATH = OUTPUT_PATH/"model"

model_save_path = "../output/best/best.pt"

## **Data Acquisition**

The dataset used in this study is derived from the Marmousi velocity model, a well-known synthetic benchmark in geophysics. This model represents a highly complex subsurface with strong velocity variations, making it ideal for testing ray-tracing algorithms and machine-learning approaches in seismic imaging.

To improve numerical stability and make training more efficient, we rescale the velocity values from m/s to **km/s**. With distances in km and traveltimes in s, this keeps the positions, slowness components, and their derivatives at order one instead of spreading them over several orders of magnitude (slowness expressed in s/m would be of order 10⁻⁴), which improves gradient propagation and makes optimization more stable.

In [None]:
# Model parameters
nx = 2301  # Number of samples in the distance dimension
nz = 751   # Number of samples in the depth dimension
dz = 4     # Depth increment (m)
dx = 4     # Distance increment (m)

# Model limits in km
xmax = nx * dx / 1000
zmax = nz * dz / 1000

vp_file = "../data/marmousi_vp.bin"

vp = np.fromfile(vp_file, dtype=np.dtype('float32').newbyteorder('<'))
vp = vp.reshape((nx, nz)).transpose() / 1000  # Converting to km/s
vp.shape
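The reshape-then-transpose step above assumes that `marmousi_vp.bin` stores the model as `nx` consecutive vertical traces of `nz` little-endian float32 samples each, with depth varying fastest. The sketch below checks that assumption on a tiny synthetic grid written to a temporary file; the grid sizes and velocities are made up for illustration.

```python
import os
import tempfile

import numpy as np

# Tiny hypothetical model: 5 traces along x, 3 samples along z, velocities in m/s.
# v = 1500 + 100 * iz + 10 * ix, so a swapped axis would be easy to spot.
nx_demo, nz_demo = 5, 3
v_true = (1500.0
          + 100.0 * np.arange(nz_demo)[:, None]
          + 10.0 * np.arange(nx_demo)[None, :]).astype(np.float32)  # shape (nz, nx)

with tempfile.TemporaryDirectory() as tmp:
    demo_file = os.path.join(tmp, "demo_vp.bin")
    # Write trace by trace: tofile writes the (nx, nz) array in C order, so z varies fastest
    v_true.T.astype("<f4").tofile(demo_file)

    # Read it back the same way as the Marmousi file above
    raw = np.fromfile(demo_file, dtype="<f4")
    v_read = raw.reshape((nx_demo, nz_demo)).transpose() / 1000  # km/s, shape (nz, nx)

print(v_read.shape)                        # (3, 5)
print(np.allclose(v_read, v_true / 1000))  # True
```

If the file were instead stored row by row (x varying fastest), the same code would silently scramble the model, so a quick plot like the one below is a useful check.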

Visualizing the Marmousi Velocity Model with Matplotlib

In [None]:
plt.figure(figsize=(18, 6))
plt.imshow(vp, extent=[0, xmax, zmax, 0], aspect='auto', cmap='viridis')
plt.colorbar(label="Velocity (km/s)")
plt.title("Marmousi - Vp")
plt.xlabel("x (km)")
plt.ylabel("z (km)")
plt.show()

### **Ray-Tracing Data Generation and Visualization Using `DataGeneratorMarmousi`**

The **`DataGeneratorMarmousi`** class generates ray-tracing data in the **Marmousi velocity model** for training machine-learning models. It interpolates the velocity field with B-splines, so that the rays travel through a smooth medium, and integrates the ray equations numerically to obtain the trajectories. The class extends `DataGenerator`, inheriting its core spatial attributes without redefining the constructor. Its key method, `run_multiple`, traces rays for every combination of initial position `(x0, z0)` and initial angle `θ0` within the requested ranges and collects the computed paths in a single **DataFrame**. This table is the training data for the **Kolmogorov-Arnold Network (KAN)**; besides the trajectories, it contains the solver derivatives (`dxdt`, `dzdt`, `dpxdt`, `dpzdt`) used by the physics-informed loss.  

The code below initializes an instance of `DataGeneratorMarmousi` and generates a dataset of ray-tracing paths within the **Marmousi velocity model**. The process consists of three main steps:  

1. **Defining the Spatial Domain**:  
   The `x_range` and `z_range` parameters define the horizontal and depth extents of the velocity model, ranging from `0` to `xmax` and `0` to `zmax`, respectively. These boundaries ensure that rays are traced within the predefined seismic model.  

2. **Generating Ray-Tracing Data (`run_multiple`)**:  
   The method `run_multiple` is called to compute multiple ray trajectories, systematically sampling different initial conditions:
   - **`x0_range=(4, 6)` and `z0_range=(1, 2)`**: Rays originate from positions within these spatial bounds (in km).  
   - **`theta_range=(45, 75)`**: Defines the range of initial propagation angles in degrees.  
   - **`vp=vp`**: Passes the Marmousi velocity model (in km/s) as input.  
   - **`factor=30`**: Controls the downsampling of the velocity field for computational efficiency.  
   - **`dx_dy=0.1`**: Sets the spacing (in km) between sampled initial positions.  
   - **`dtheta=5`**: Sets the spacing (in degrees) between sampled initial angles.  
   - **`t_max=0.4`**: Sets the maximum traveltime of each ray, 0.4 s (400 ms).  

   The method iterates over all combinations of `(x0, z0, θ0)`, traces the corresponding rays, and stores the results in a **DataFrame** (`df`). For each ray and time sample, the table holds the position (`x`, `z`), the slowness-vector components (`px`, `pz`), and their time derivatives computed by the solver.  

3. **Visualization (`plot`)**:  
   The `plot` method overlays the generated ray paths on the velocity model: the **velocity field** is shown as a colormap and the **ray trajectories** as curves, showing how the rays bend as they cross the subsurface. `figsize=(22, 6)` gives a wide figure that suits the elongated model, and `plt.show()` displays it.  

This workflow produces a dataset of many rays with different initial conditions, suitable for training **Kolmogorov-Arnold Networks (KANs)** to predict ray trajectories while the physics-informed loss keeps them consistent with the ray equations.
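Step 2 says that `run_multiple` iterates over every combination of `(x0, z0, θ0)`. The internals of `DataGeneratorMarmousi` are not shown here, so the following is only a sketch of how such a grid could be built with `np.arange` and `itertools.product`. The helper name `initial_condition_grid` and the inclusive treatment of the upper bounds are assumptions.

```python
from itertools import product

import numpy as np


def initial_condition_grid(x0_range, z0_range, theta_range, dx_dy, dtheta):
    """Hypothetical helper: all (x0, z0, theta0) combinations on a regular grid.

    Upper bounds are treated as inclusive (an assumption), and values are rounded
    to avoid floating-point drift such as 4.300000000000001.
    """
    def axis(lo, hi, step):
        n = int(round((hi - lo) / step)) + 1  # number of grid nodes, both endpoints included
        return np.round(lo + step * np.arange(n), 6)

    xs = axis(*x0_range, dx_dy)
    zs = axis(*z0_range, dx_dy)
    thetas = axis(*theta_range, dtheta)
    return list(product(xs, zs, thetas))


grid = initial_condition_grid((4, 6), (1, 2), (45, 75), dx_dy=0.1, dtheta=5)
print(len(grid))  # 21 * 11 * 7 = 1617 under the inclusive-bounds assumption
```

If the generator treats the bounds as exclusive, or drops rays that leave the model, the number of unique initial conditions reported by `initial_conditions.shape` later in the notebook will differ from this count.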


In [None]:
x_range = (0, xmax)
z_range = (0, zmax)

data_gen = DataGeneratorMarmousi(
    x_range=x_range,
    z_range=z_range
)
df = data_gen.run_multiple(x0_range=(4, 6),
                           z0_range=(1, 2),
                           theta_range=(45, 75),
                           vp=vp,
                           factor=30,
                           dx_dy=0.1,
                           dtheta=5,
                           t_max=0.4,
                           )
fig = data_gen.plot(df, figsize=(22, 6))
plt.show()
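For intuition about what each traced ray contains, here is a minimal sketch of kinematic ray tracing in a smooth 2-D medium, with traveltime `t` as the integration variable. It uses the standard ray equations dx/dt = v²·px, dz/dt = v²·pz, dpx/dt = -(∂v/∂x)/v, dpz/dt = -(∂v/∂z)/v, where the slowness vector (px, pz) has magnitude 1/v. This is not the implementation inside `DataGeneratorMarmousi`, which interpolates the gridded model with B-splines: here the velocity is a hypothetical linear gradient `v(z) = V0 + G·z`, a fixed-step fourth-order Runge-Kutta scheme stands in for the class's solver, and θ0 is assumed to be measured from the vertical.

```python
import numpy as np

V0, G = 1.5, 0.5  # hypothetical medium: v(z) = V0 + G * z (km/s and 1/s)


def velocity_and_gradient(x, z):
    """Velocity (km/s) and its spatial gradient for the linear-gradient model."""
    v = V0 + G * z
    return v, 0.0, G  # v, dv/dx, dv/dz


def ray_rhs(state):
    """Right-hand side of the kinematic ray equations, parameterized by traveltime."""
    x, z, px, pz = state
    v, dvdx, dvdz = velocity_and_gradient(x, z)
    return np.array([
        v**2 * px,   # dx/dt
        v**2 * pz,   # dz/dt
        -dvdx / v,   # dpx/dt
        -dvdz / v,   # dpz/dt
    ])


def trace_ray(x0, z0, theta0_deg, t_max=0.4, dt=1e-3):
    """Integrate one ray with classical RK4; returns an array of shape (n_steps + 1, 4)."""
    v0, _, _ = velocity_and_gradient(x0, z0)
    theta0 = np.deg2rad(theta0_deg)
    # Initial slowness vector with |p| = 1 / v0, angle measured from the vertical (assumption)
    state = np.array([x0, z0, np.sin(theta0) / v0, np.cos(theta0) / v0])
    n_steps = int(round(t_max / dt))
    path = [state]
    for _ in range(n_steps):
        k1 = ray_rhs(state)
        k2 = ray_rhs(state + 0.5 * dt * k1)
        k3 = ray_rhs(state + 0.5 * dt * k2)
        k4 = ray_rhs(state + dt * k3)
        state = state + dt / 6.0 * (k1 + 2 * k2 + 2 * k3 + k4)
        path.append(state)
    return np.array(path)


ray_path = trace_ray(x0=5.0, z0=1.0, theta0_deg=60.0)
print(ray_path.shape)  # (401, 4): 0 to 400 ms in 1 ms steps
```

Two properties make a quick sanity check: because v depends only on z, px stays constant (Snell's law), and v²(px² + pz²) stays equal to 1 along the ray. Assuming the generator uses the same formulation, the four right-hand-side terms correspond to the `dxdt`, `dzdt`, `dpxdt`, and `dpzdt` columns that the physics loss compares against later.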

Here, we define the feature groups for our architecture: **KAN features** (the inputs to the Kolmogorov-Arnold Network), **Arch features** (the KAN features plus the per-sample physics-loss weight `pi_weight` and the solver derivatives used by the physics-informed loss), and **target features** (the quantities the network predicts). Keeping these groups separate makes the data flow through the model explicit.

In [None]:
kan_features = ['x0', 'z0', 'theta0_p', 't']
arch_features = kan_features + ['pi_weight', 'dxdt', 'dzdt', 'dpxdt', 'dpzdt']
target = ['x', 'z', 'px', 'pz']
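`physics_loss_fn` and `BalancedBatchSampler` read the training tensor by column position rather than by name (time at index 3, the weight at index 4, the solver derivatives at indices 5 to 8), so the order of `arch_features` is part of the contract. A cheap guard such as the following catches accidental reordering before training starts; the mapping `EXPECTED_POSITIONS` is introduced here for illustration.

```python
kan_features = ['x0', 'z0', 'theta0_p', 't']
arch_features = kan_features + ['pi_weight', 'dxdt', 'dzdt', 'dpxdt', 'dpzdt']

# Column positions assumed by physics_loss_fn and BalancedBatchSampler(t_index=3)
EXPECTED_POSITIONS = {
    'x0': 0, 'z0': 1, 'theta0_p': 2, 't': 3,
    'pi_weight': 4, 'dxdt': 5, 'dzdt': 6, 'dpxdt': 7, 'dpzdt': 8,
}

for name, position in EXPECTED_POSITIONS.items():
    assert arch_features.index(name) == position, f"{name} must be column {position}"

# The KAN itself only receives the first four columns
assert arch_features[:len(kan_features)] == kan_features
print("Column order matches the positions used by the loss and the sampler.")
```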

### **Weights in PINN Loss Function**  

In **Physics-Informed Neural Networks (PINNs)**, loss weighting balances data-driven learning against physical constraints. Here the weights are fixed per sample: they are computed once from the training data before training starts and set how strongly each point's physics residual counts. In regions with dense data points, the model can rely on data fidelity, while in sparsely sampled areas the larger weights push it to adhere to the governing equations.  

Here, the **weighting strategy is based on spatial frequency**. The code counts the number of data points within square cells of the (x, z) domain (0.1 km on a side by default), assigning higher weights to underrepresented cells and lower weights to densely populated ones. This **adaptive weighting** keeps the densely sampled regions from dominating the loss and helps the model generalize across the entire domain.  

By incorporating these weights into the PINN loss function, the model learns from both observed data and physics-based constraints, which improves stability, accuracy, and robustness when predicting ray trajectories through the Marmousi model.
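The functions `get_frequency` and `add_frequency_to_data` below loop over every square. As an alternative (not the approach used in this notebook), NumPy can count all points in one call with `np.histogram2d` and look each point's square back up with `np.digitize`. This is a minimal sketch, assuming half-open squares and points strictly inside the outer edges; the function name `point_counts` is hypothetical.

```python
import numpy as np


def point_counts(x, z, x_edges, z_edges):
    """Hypothetical helper: how many points share each point's (x, z) square.

    Squares are half-open [left, right), like the masks in get_frequency, so every
    point must satisfy x_edges[0] <= x < x_edges[-1] (and likewise for z).
    """
    x, z = np.asarray(x), np.asarray(z)
    inside = ((x >= x_edges[0]) & (x < x_edges[-1]) &
              (z >= z_edges[0]) & (z < z_edges[-1]))
    if not inside.all():
        raise ValueError("All points must lie strictly inside the outer edges.")

    # Count points per square in a single call
    counts, _, _ = np.histogram2d(x, z, bins=[x_edges, z_edges])

    # Look up each point's square; digitize returns 1-based bin numbers here
    ix = np.digitize(x, x_edges) - 1
    iz = np.digitize(z, z_edges) - 1
    return counts[ix, iz].astype(int)


step = 0.1
x_edges = np.arange(0.0, 1.0 + step / 2, step)  # 0.0, 0.1, ..., 1.0
z_edges = np.arange(0.0, 0.5 + step / 2, step)  # 0.0, 0.1, ..., 0.5

x_demo = np.array([0.05, 0.07, 0.55, 0.95])
z_demo = np.array([0.05, 0.08, 0.25, 0.45])
print(point_counts(x_demo, z_demo, x_edges, z_edges))  # [2 2 1 1]
```

For points inside the domain, the result should match the `frequency` column produced by `add_frequency_to_data`, although the notebook rounds its edges to three decimals in float32, so a point lying exactly on an edge could occasionally be assigned to a neighboring square.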

In [None]:
def get_squares_limits(data: pd.DataFrame, restrictions: dict, step: float) -> np.ndarray:
    """
    Generates a set of square limits based on specified feature restrictions and step size.

    This function partitions the feature space into discrete intervals based on given restrictions,
    creating a grid of square regions for further analysis.

    Args:
        data (pd.DataFrame): The input dataset containing feature columns.
        restrictions (dict): A dictionary defining the range of each feature. 
            Each key corresponds to a feature name, and its value is a dictionary 
            with 'min' and 'max' keys specifying the range.
        step (float): The step size used to discretize the feature space.

    Returns:
        np.ndarray: A NumPy array containing all possible square intervals formed by the 
        specified feature restrictions.
    
    Raises:
        AssertionError: If any feature in `restrictions` is not present in the dataset.

    Example:
        >>> data = pd.DataFrame({'x': [1, 2, 3], 'y': [4, 5, 6]})
        >>> restrictions = {'x': {'min': 0, 'max': 3}, 'y': {'min': 4, 'max': 6}}
        >>> get_squares_limits(data, restrictions, step=1)
        array([[[0., 1.], [4., 5.]],
               [[0., 1.], [5., 6.]],
               [[1., 2.], [4., 5.]],
               [[1., 2.], [5., 6.]],
               [[2., 3.], [4., 5.]],
               [[2., 3.], [5., 6.]]], dtype=float32)
    """
    assert all([feature in data.columns for feature in restrictions.keys()]), \
        "Some features presented in restrictions are not in the data."

    # Generate the intervals for each feature
    limits_map = {}
    for name, boundary in restrictions.items():
        aux = np.arange(boundary['min'], boundary['max'] + step, step=step, dtype='float32')
        aux = [round(x, 3) for x in aux]
        limits_map[name] = [(aux[i], aux[i+1]) for i in range(len(aux) - 1)]

    # Create all combinations of the intervals between features
    combinations = list(product(*limits_map.values()))

    result = [np.array(combination, dtype='float32') for combination in combinations]

    return np.array(result, dtype='float32')

In [None]:
def get_frequency(data: pd.DataFrame, restrictions: dict, step: float = 0.1) -> pd.DataFrame:
    """
    Computes the frequency of data points within predefined square regions in the feature space.

    This function divides the input data into grid-based square regions and counts the number 
    of data points falling within each region. The result is stored in a DataFrame.

    Args:
        data (pd.DataFrame): The dataset containing the features to be analyzed.
        restrictions (dict): A dictionary defining feature-wise range restrictions.
            Each key corresponds to a feature name, with 'min' and 'max' specifying the range.
        step (float, optional): The step size used to define square regions. Default is 0.1.

    Returns:
        pd.DataFrame: A DataFrame containing each square's limits and the corresponding data point count.

    Example:
        >>> data = pd.DataFrame({'x': [0.5, 1.5, 2.5], 'y': [4.5, 5.5, 6.5]})
        >>> restrictions = {'x': {'min': 0, 'max': 3}, 'y': {'min': 4, 'max': 6}}
        >>> get_frequency(data, restrictions, step=1)
           square   frequency
        0  [[0, 1], [4, 5]]  1
        1  [[0, 1], [5, 6]]  0
        2  [[1, 2], [4, 5]]  0
        3  [[1, 2], [5, 6]]  1
        4  [[2, 3], [4, 5]]  0
        5  [[2, 3], [5, 6]]  0
    """
    squares_limits = get_squares_limits(data, restrictions, step)

    # Count points in each square
    frequencies = []
    for square in squares_limits:
        mask = np.ones(len(data), dtype=bool)
        for feature, limits in zip(restrictions.keys(), square):
            sqr_min, sqr_max = limits
            mask &= (data[feature] >= sqr_min) & (data[feature] < sqr_max)

        # Count points inside this square
        frequencies.append(np.sum(mask))

    # Prepare result as a DataFrame
    frequency_df = pd.DataFrame(
        data={
            "square": list(squares_limits),
            "frequency": frequencies
        }
    )

    return frequency_df


In [None]:
def add_frequency_to_data(data: pd.DataFrame, frequency_df: pd.DataFrame,
                          features: tuple = ('x', 'z')) -> pd.DataFrame:
    """
    Adds frequency information to the dataset based on predefined spatial regions.

    This function assigns a frequency value to each data point by matching it to the 
    corresponding region (square) defined in `frequency_df`. The frequency represents 
    the number of data points found in that region, ensuring that each sample is 
    weighted accordingly for further processing. Points outside every square keep 
    a frequency of 0.

    Args:
        data (pd.DataFrame): The input dataset containing feature columns.
        frequency_df (pd.DataFrame): A DataFrame with frequency counts for different 
            spatial regions, where each row contains:
            - "square": A list of tuples defining the boundaries of the region.
            - "frequency": The number of data points within that region.
        features (tuple, optional): Names of the columns that correspond, in order, 
            to the dimensions of each square (the keys of the `restrictions` passed 
            to `get_frequency`). Default is ('x', 'z').

    Returns:
        pd.DataFrame: A copy of `data` with an added "frequency" column.

    Example:
        >>> data = pd.DataFrame({'x': [0.5, 1.5, 2.5], 'y': [4.5, 5.5, 6.5]})
        >>> frequency_df = pd.DataFrame({
        ...     "square": [[(0,1), (4,5)], [(1,2), (5,6)], [(2,3), (6,7)]],
        ...     "frequency": [10, 5, 3]
        ... })
        >>> add_frequency_to_data(data, frequency_df, features=('x', 'y'))
             x    y  frequency
        0  0.5  4.5        10
        1  1.5  5.5         5
        2  2.5  6.5         3
    """
    missing = [feature for feature in features if feature not in data.columns]
    assert not missing, f"Features {missing} are not in the data."

    # Initialize an array to store frequency values
    frequencies = np.zeros(len(data), dtype=int)

    # Iterate over each row in the frequency DataFrame to assign frequencies
    for _, square_row in frequency_df.iterrows():
        square = square_row['square']
        frequency = square_row['frequency']

        # Build a mask for the rows of `data` that fall within the current square.
        # Match the square's limits to the named features, not to the first columns
        # of `data`, which may hold other variables.
        mask = np.ones(len(data), dtype=bool)
        for feature, (min_val, max_val) in zip(features, square):
            values = data[feature].to_numpy()
            mask &= (values >= min_val) & (values < max_val)

        # Assign the frequency value to the matching rows
        frequencies[mask] = frequency

    # Create a copy of the dataset and add the frequency column
    data_with_frequency = data.copy()
    data_with_frequency['frequency'] = frequencies

    return data_with_frequency


In [None]:
def plot_surface(data: pd.DataFrame, filename: str = "output/surface_plot.html") -> None:
    """
    Generates a 3D surface plot of point frequencies within spatial regions.

    This function visualizes the frequency distribution of data points across 
    predefined grid regions. It extracts midpoints of the spatial squares and 
    maps the corresponding frequencies, creating a structured 3D surface plot 
    using Plotly.

    Args:
        data (pd.DataFrame): A DataFrame containing:
            - "square": A list of tuples defining the spatial region boundaries.
            - "frequency": The count of data points within each square.
        filename (str, optional): The output file path to save the interactive 
            HTML plot. Default is `"output/surface_plot.html"`.

    Returns:
        None: The function generates and saves the plot but does not return a value.

    Example:
        >>> data = pd.DataFrame({
        ...     "square": [[(0,1), (4,5)], [(1,2), (5,6)], [(2,3), (6,7)]],
        ...     "frequency": [10, 5, 3]
        ... })
        >>> plot_surface(data, filename="surface_plot.html")
    """
    # Extract midpoints and corresponding frequencies
    squares = data['square'].tolist()  # Assuming 'square' column stores lists
    frequencies = data['frequency'].values

    # Compute midpoints for visualization
    x = np.array([(interval[0][0] + interval[0][1]) / 2 for interval in squares])
    y = np.array([(interval[1][0] + interval[1][1]) / 2 for interval in squares])
    z = frequencies

    # Create a structured grid for the surface plot
    unique_x = np.unique(x)
    unique_y = np.unique(y)
    X, Z = np.meshgrid(unique_x, unique_y)

    # Map frequency values to the grid
    freq = np.zeros_like(X)
    for i, x_val in enumerate(unique_x):
        for j, y_val in enumerate(unique_y):
            mask = (x == x_val) & (y == y_val)
            if np.any(mask):
                freq[j, i] = z[mask][0]  # Assign the frequency to the grid point

    # Create 3D surface plot
    fig = go.Figure()
    fig.add_trace(go.Surface(z=freq, x=X, y=Z, opacity=0.8))

    # Customize layout
    fig.update_layout(
        scene=dict(
            xaxis_title='X',
            yaxis_title='Z',
            zaxis_title='Frequency'
        ),
        title='Surface Plot of Points in Squares'
    )

    # Save plot as an interactive HTML file
    fig.write_html(filename)


### **Splitting the Dataset into Training, Validation, and Test Sets**  

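Proper dataset partitioning is essential for robust model evaluation. We divide the data into **training, validation, and test sets** by initial condition `(x0, z0, theta0_p)` rather than by individual rows: 60% of the rays go to training, and the remaining rays are split evenly between validation and test. Every time sample of a given ray therefore lands in exactly one set, so the model is evaluated on trajectories it has never seen rather than on points interpolated between training samples of the same ray. The **training set** is used to optimize model parameters, the **validation set** helps tune hyperparameters and guard against overfitting, and the **test set** provides an unbiased estimate of final model performance.  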


In [None]:
df.shape

In [None]:
initial_conditions = df[['x0', 'z0', 'theta0_p']].drop_duplicates()
aux = initial_conditions.copy()

ic_train = aux.sample(frac=0.6, random_state=SEED)
ic_val = aux.drop(ic_train.index).sample(frac=0.5, random_state=SEED)
ic_test = aux.drop(list(ic_train.index) + list(ic_val.index))

print("Train Initial Conditions:", len(ic_train))
print("Validation Initial Conditions:", len(ic_val))
print("Test Initial Conditions:", len(ic_test))

In [None]:
initial_conditions.shape

Each ray trajectory is sampled at **401 points**, covering traveltimes from **0 to 400 ms** with a time step (**dt**) of **1 ms**, consistent with `t_max=0.4` s. Each split therefore contains 401 rows per initial condition, and this fine time sampling resolves how each ray bends as it crosses velocity contrasts.  

In [None]:
df_train = df.merge(ic_train, on=['x0', 'z0', 'theta0_p'], how='inner')
df_val = df.merge(ic_val, on=['x0', 'z0', 'theta0_p'], how='inner')
df_test = df.merge(ic_test, on=['x0', 'z0', 'theta0_p'], how='inner')

print("Train Size:", len(df_train))
print("Validation Size:", len(df_val))
print("Test Size:", len(df_test))

### Saving the Datasets

Saving the generated splits to CSV lets later runs reload exactly the same data instead of regenerating it, which supports reproducibility and makes the experiment easier to audit.

In [None]:
df_train.to_csv(DATA_PATH/'train_marmousi_f30.csv', index=False)
df_val.to_csv(DATA_PATH/'val_marmousi_f30.csv', index=False)
df_test.to_csv(DATA_PATH/'test_marmousi_f30.csv', index=False)

### Reading the Datasets

After saving the datasets, let's load them so we can skip regenerating and preprocessing later.

In [None]:
df_train = pd.read_csv(DATA_PATH/'train_marmousi_f30.csv')
df_val = pd.read_csv(DATA_PATH/'val_marmousi_f30.csv')
df_test = pd.read_csv(DATA_PATH/'test_marmousi_f30.csv')

### Plotting the Frequencies

We are balancing learning from two sources: the actual data points (data loss) and the governing physical equations (physics loss). However, the data generated from processes like ray tracing is often not uniformly distributed across the entire domain. Some areas may be densely sampled with many ray paths crossing them, while others might be sparsely sampled.

**Why is this a problem?**
If we treat all points equally, the model's training will be dominated by the high-density regions. It will learn to be very accurate in those areas, but may neglect the sparsely populated regions, leading to poor generalization.

**How do we fix it?**
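We introduce a weighting scheme for the physics loss that decreases with the data frequency: each point's weight is the reciprocal of its min-max normalized frequency, clipped at 10, so the weights range from 1 in the densest squares to 10 in the sparsest.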
-   **High-Frequency Regions (many data points):** These regions receive a *lower* weight. The model can rely more on the abundant data to learn the correct behavior.
-   **Low-Frequency Regions (few data points):** These regions receive a *higher* weight. This forces the model to pay closer attention to the governing physics (the ray-tracing equations) to make accurate predictions, compensating for the lack of data.
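Concretely, the cell that builds `pi_weight` later in this notebook min-max normalizes the per-point counts, takes the reciprocal, and clips at 10. The densest square gets weight 1, and the sparsest hits the cap because its normalized frequency is 0 (pandas returns inf for 1/0 rather than raising). The toy counts below are made up to show the resulting curve.

```python
import pandas as pd

# Hypothetical per-point counts, as in the 'frequency' column
frequency_demo = pd.Series([1, 5, 11, 21, 41, 81], dtype=float)

normalized_demo = (frequency_demo - frequency_demo.min()) / (frequency_demo.max() - frequency_demo.min())
weight_demo = (1 / normalized_demo).clip(0, 10)  # 1 / 0 gives inf, which the clip caps at 10

print(pd.DataFrame({'frequency': frequency_demo,
                    'normalized': normalized_demo,
                    'pi_weight': weight_demo}))
```

Every square whose normalized count falls below 0.1 receives the same maximum weight (here the counts 1 and 5 both map to 10), so in the sparsest part of the domain the cap, rather than the exact count, sets the weight.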

**The Role of the Plot**
By plotting the frequency as a 3D surface, we get an immediate visual understanding of our data distribution. This plot helps us:
1.  **Verify Data Coverage:** Quickly identify which parts of the velocity model are well-sampled and which are not.
2.  **Understand the Weights:** Intuitively grasp how the physics loss will be weighted across the domain. The "valleys" in the plot correspond to areas where the physics loss will have the most impact.
3.  **Debug Data Generation:** Unexpected gaps or patterns in the frequency plot can signal issues in the data generation process itself.

In short, plotting the frequency is a key diagnostic step that validates our strategy for creating a more robust and accurate physics-informed model.

In [None]:
restrictions = {
    'x': {'min': 0, 'max': xmax},
    'z': {'min': 0, 'max': zmax},
}
frequency_df = get_frequency(df_train, restrictions)

In [None]:
plot_surface(frequency_df, filename=OUTPUT_PATH/"frequency_plot_marmousi_f30.html")

In [None]:
df_train_freq = add_frequency_to_data(df_train, frequency_df)

# Min-max normalize the counts to [0, 1]
freq = df_train_freq['frequency']
df_train_freq['normalized_frequency'] = (freq - freq.min()) / (freq.max() - freq.min())

# Inverse weighting: sparse squares get larger physics-loss weights.
# The sparsest square has normalized frequency 0 (weight inf), so cap the weights at 10;
# the densest square gets weight 1.
df_train_freq['pi_weight'] = 1 / df_train_freq['normalized_frequency']
df_train_freq['pi_weight'] = df_train_freq['pi_weight'].clip(0, 10)
df_train_freq.head()

In [None]:
df_train_freq.describe()

In [None]:
print(df_train_freq[arch_features+target].head().to_latex(index=False, float_format="{:0.2f}".format, escape=True))

In [None]:
fig = go.Figure(data=[go.Scatter3d(
    x=df_train_freq['x'],
    y=df_train_freq['z'],
    z=df_train_freq['pi_weight'],
    mode='markers',
    marker=dict(
        size=5,
        color=df_train_freq['pi_weight'],  # Color by pi_weight
        colorscale='Viridis',
        opacity=0.8
    )
)])

# Add labels
fig.update_layout(
    scene=dict(
        xaxis_title='x',
        yaxis_title='z',
        zaxis_title='pi_weight'
    ),
    margin=dict(l=0, r=0, b=0, t=0)
)

fig.write_html(OUTPUT_PATH/"train_pi_weight_plot_marmousi_f30.html")

### Creating the DataLoaders

#### Balanced Batch Sampler Definition

We use a custom `BalancedBatchSampler` to ensure that every training batch contains a fixed number of samples at the initial conditions (time `t=0`, two per batch) alongside samples from the rest of the trajectories. This is crucial because our physics-informed loss function includes a Mean Squared Error (MSE) term that penalizes deviations from these initial conditions. With only one `t=0` row per 401-point ray, randomly drawn batches would often contain none; guaranteeing their presence means this part of the loss is applied in every batch, which leads to more stable training and a model that respects the initial state of the system.

In [None]:
class BalancedBatchSampler(Sampler):
    def __init__(self, dataset, batch_size, t_index=3, shuffle=True):
        """
        A batch sampler that guarantees every batch contains samples at t = 0.

        Each batch holds exactly two samples whose time value is 0 (the initial
        conditions) and `batch_size - 2` samples with t != 0. Both groups are
        cycled, so the scarce t = 0 samples are reused across batches.

        Args:
            dataset (Dataset): The dataset to sample from. `dataset[i][0]` must be
                the input features of sample i.
            batch_size (int): The number of samples per batch (a positive even integer).
            t_index (int, optional): The column index of time `t` in the input
                features. Default is 3.
            shuffle (bool, optional): Whether to shuffle indices within each group
                and within each batch. Default is True.

        Attributes:
            dataset (Dataset): The dataset to sample from.
            batch_size (int): The number of samples per batch.
            t_index (int): The column index of time in the input features.
            shuffle (bool): Whether to shuffle the batch indices.
            t_zero_indices (list): Indices of samples with t equal to 0.
            non_t_zero_indices (list): Indices of samples with t not equal to 0.

        Methods:
            __iter__(): Generates balanced batches of indices.
            __len__(): Returns the number of batches.
        """
        if batch_size % 2 != 0 or batch_size <= 0:
            raise ValueError("batch_size must be a positive even integer.")
        
        self.dataset = dataset
        self.batch_size = batch_size
        self.t_index = t_index
        self.shuffle = shuffle

        # Split the dataset into two parts: with t=0 and without t=0
        self.t_zero_indices = []
        self.non_t_zero_indices = []

        for i in range(len(self.dataset)):
            data = self.dataset[i]
            try:
                t_sample = data[0][self.t_index]
            except (IndexError, TypeError):
                raise ValueError(f"The time index t_index={self.t_index} is not valid.")
            if t_sample == 0:
                self.t_zero_indices.append(i)
            else:
                self.non_t_zero_indices.append(i)

        if not self.t_zero_indices or not self.non_t_zero_indices:
            raise ValueError("There are not enough samples in each class to create balanced batches.")

    def __iter__(self):
        if self.shuffle:
            random.shuffle(self.t_zero_indices)
            random.shuffle(self.non_t_zero_indices)

        t_zero_iter = cycle(self.t_zero_indices)
        non_t_zero_iter = cycle(self.non_t_zero_indices)

        num_batches = len(self.dataset) // self.batch_size

        for _ in range(num_batches):
            t_zero_batch = [next(t_zero_iter) for _ in range(2)] # Samples with t=0 per batch
            non_t_zero_batch = [next(non_t_zero_iter) for _ in range(self.batch_size-2)]
            batch_indices = t_zero_batch + non_t_zero_batch

            if self.shuffle:
                random.shuffle(batch_indices)

            yield batch_indices

    def __len__(self):
        return len(self.dataset) // self.batch_size
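The constructor above finds the `t = 0` rows by indexing the dataset one sample at a time, which can be slow for a few hundred thousand rows. When the time values are available as an array, the same split can be computed with one vectorized comparison. The sketch below uses plain NumPy and a hypothetical generator `balanced_batches`; it mirrors the sampler's policy of two `t = 0` indices per batch, cycling through each group, and checks the batch composition on toy data.

```python
from itertools import cycle, islice

import numpy as np


def balanced_batches(t_values, batch_size, n_t_zero=2, seed=0):
    """Hypothetical helper: index batches with exactly n_t_zero samples at t == 0.

    Mirrors the BalancedBatchSampler policy, but finds the two groups with one
    vectorized comparison instead of a per-sample loop.
    """
    rng = np.random.default_rng(seed)
    t_values = np.asarray(t_values)
    t_zero = rng.permutation(np.flatnonzero(t_values == 0))
    non_t_zero = rng.permutation(np.flatnonzero(t_values != 0))
    if t_zero.size == 0 or non_t_zero.size == 0:
        raise ValueError("Need samples both at t == 0 and at t != 0.")

    # Cycle through each group so the scarce t == 0 samples are reused
    t_zero_iter = cycle(t_zero.tolist())
    non_t_zero_iter = cycle(non_t_zero.tolist())
    for _ in range(len(t_values) // batch_size):
        batch = (list(islice(t_zero_iter, n_t_zero))
                 + list(islice(non_t_zero_iter, batch_size - n_t_zero)))
        yield rng.permutation(batch).tolist()  # mix the two groups within the batch


# Toy data: 3 rays x 401 time samples (t in seconds), so only 3 rows have t == 0
t_demo = np.tile(np.arange(401) * 1e-3, 3)
batches = list(balanced_batches(t_demo, batch_size=256))
print(len(batches))                                    # 4 = 1203 // 256
print([int(np.sum(t_demo[b] == 0)) for b in batches])  # [2, 2, 2, 2]
```

Cycling means the few `t = 0` rows are revisited in nearly every batch, which keeps the initial-condition term active; the trade-off is that those rows are heavily oversampled relative to the rest of each trajectory.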


#### Separating Dependent and Independent Variables

The dataset is divided into independent variables (inputs/features) and dependent variables (outputs/targets) for each set (train, validation, and test). This separation is essential for training and evaluating machine learning models. The variables used are:

- **Independent variables (features):**
    - `arch_features` for training (includes KAN features and additional physics-informed features)
    - `kan_features` for validation and test (only KAN input features)

- **Dependent variables (targets):**
    - `target` (the physical quantities to be predicted: x, z, px, pz)

The tensors `X_train`, `X_val`, `X_test`, `y_train`, `y_val`, and `y_test` are created from the corresponding DataFrames for use in PyTorch models.

In [None]:
# Converting data to tensors
X_train = torch.as_tensor(df_train_freq[arch_features].values, dtype=torch.float32)
y_train = torch.as_tensor(df_train_freq[target].values, dtype=torch.float32)

X_val = torch.as_tensor(df_val[kan_features].values, dtype=torch.float32)
y_val = torch.as_tensor(df_val[target].values, dtype=torch.float32)

X_test = torch.as_tensor(df_test[kan_features].values, dtype=torch.float32)
y_test = torch.as_tensor(df_test[target].values, dtype=torch.float32)

# Creating the Dataset
train_dataset = TensorDataset(X_train, y_train)
val_dataset = TensorDataset(X_val, y_val)
test_dataset = TensorDataset(X_test, y_test)

# Creating a loader for each dataset
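sampler = BalancedBatchSampler(train_dataset, batch_size=256)
# batch_sampler already sets the batch size and order; DataLoader rejects
# batch_size, shuffle, sampler, or drop_last when batch_sampler is given
train_loader = DataLoader(dataset=train_dataset,
                          batch_sampler=sampler,
                          )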

val_loader = DataLoader(dataset=val_dataset,
                        shuffle=False,
                        batch_size=256
                        )

test_loader = DataLoader(dataset=test_dataset,
                         shuffle=False,
                         batch_size=len(test_dataset)
                         )


## Physics Loss Function

The `physics_loss_fn` is the core of the Physics-Informed Neural Network (PINN) approach, responsible for ensuring that the model's predictions adhere to the underlying physical principles of ray tracing. This function computes a composite loss that combines two critical components:

1.  **Initial Condition Loss**: It first verifies that the model correctly reproduces the initial state of the system. For any data points where time `t=0`, it calculates the Mean Squared Error (MSE) between the model's predicted position `(x, z)` and the known initial position `(x0, z0)`. This forces the learned trajectories to start at the correct locations.

2.  **Physics Residual Loss**: It then enforces the governing differential equations of ray tracing. Using automatic differentiation (`torch.autograd.grad`), it computes the derivatives of the model's outputs (x, z, px, pz) with respect to time `t`. These neural network-derived gradients are compared against the true derivatives provided by a classical numerical solver. The discrepancy between them forms the "physics residual." This residual is weighted to account for non-uniform data density, ensuring that the model learns the physics consistently across the entire domain.

The final loss is the initial condition loss plus the weighted physics residual loss, where the residual term is a weighted mean: the sum of `weights * residual` divided by the sum of the weights. By minimizing this composite loss, the model learns to generate trajectories that are not only consistent with the training data but also physically plausible according to the ray tracing equations.
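The key mechanism in the physics residual is differentiating the network output with respect to its time input. A minimal sketch of that step, using a hypothetical closed-form "model" instead of a KAN so that the exact derivative is known: for x(t) = x0 + c·t², autograd should return 2·c·t. Passing `create_graph=True`, as `physics_loss_fn` does, keeps the derivative differentiable so that a residual built from it can be backpropagated into the model parameters.

```python
import torch

c = torch.tensor(0.7, requires_grad=True)  # stand-in for a trainable parameter


def toy_model(inputs):
    """Hypothetical model: x(t) = x0 + c * t**2, for inputs [x0, t] of shape (N, 2)."""
    x0, t = inputs[:, 0:1], inputs[:, 1:2]
    return x0 + c * t**2


x0_demo = torch.full((5, 1), 4.0)
t_demo = torch.tensor([[0.0], [0.1], [0.2], [0.3], [0.4]], requires_grad=True)
x_out_demo = toy_model(torch.cat([x0_demo, t_demo], dim=1))

# Vector-Jacobian product with a vector of ones: because each output row depends
# only on its own t, this returns the per-sample derivative dx/dt
dxdt_demo = torch.autograd.grad(
    outputs=x_out_demo,
    inputs=t_demo,
    grad_outputs=torch.ones_like(x_out_demo),
    create_graph=True,  # keep the graph so the residual below can be backpropagated
)[0]
print(torch.allclose(dxdt_demo, 2 * c * t_demo))  # True: the analytic derivative is 2 c t

# A residual built from the derivative still reaches the model parameter
dxdt_solver = torch.zeros_like(dxdt_demo)  # hypothetical solver derivatives
residual = ((dxdt_demo - dxdt_solver) ** 2).mean()
residual.backward()
print(c.grad)  # d/dc of mean((2 c t)^2) = 8 c mean(t^2), i.e. 0.336 in exact arithmetic
```

The per-sample interpretation of `grad_outputs=torch.ones_like(...)` relies on each row of the batch being processed independently, which holds for a KAN applied row by row but would not hold for layers that mix samples, such as batch normalization.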

In [None]:
# Physics-informed loss: initial-condition term plus weighted ray-equation residual
def physics_loss_fn(model, input_tensor, *args, **kwargs):
    # Separate the inputs and ensure that the variables require gradient
    x0 = input_tensor[:, 0].view(-1, 1)  # x0
    z0 = input_tensor[:, 1].view(-1, 1)  # z0
    theta0 = input_tensor[:, 2].view(-1, 1)  # theta0
    t_var = input_tensor[:, 3].view(-1, 1)  # t
    
    # Weights for each sample
    weights = input_tensor[:, 4].view(-1, 1)
    
    # Derivatives from the numerical ray tracer (the targets of the physics residual)
    dxdt_true = input_tensor[:, 5].view(-1, 1)
    dzdt_true = input_tensor[:, 6].view(-1, 1)
    dpxdt_true = input_tensor[:, 7].view(-1, 1)
    dpzdt_true = input_tensor[:, 8].view(-1, 1)

    # Only t needs gradients: every residual is a derivative with respect to t
    t_var.requires_grad_(True)

    # Recalculate yhat using the variables to ensure the connection
    x_reconstructed = torch.cat([x0, z0, theta0, t_var], dim=1)
    yhat = model(x_reconstructed)
    
    x_out = yhat[:, 0].view(-1, 1)
    z_out = yhat[:, 1].view(-1, 1)
    px_out = yhat[:, 2].view(-1, 1)
    pz_out = yhat[:, 3].view(-1, 1)
    
    # Initial-condition term, computed only if the batch contains t = 0 samples
    t_zero_mask = (t_var.view(-1) == 0)
    if t_zero_mask.any():
        # Index the (N, 1) tensors with the (N,) mask so both sides stay (k, 1);
        # mixing (k,) and (k, 1) shapes would broadcast to a (k, k) matrix
        x0_t0 = x0[t_zero_mask]
        z0_t0 = z0[t_zero_mask]

        # Ensure that at t = 0, x = x0 and z = z0
        x_initial_condition = (x_out[t_zero_mask] - x0_t0) ** 2
        z_initial_condition = (z_out[t_zero_mask] - z0_t0) ** 2
        initial_condition_loss = torch.mean(x_initial_condition + z_initial_condition)
    else:
        # If t = 0 is not present, the loss is 0
        initial_condition_loss = 0
    

    # Calculate the necessary gradients
    dxdt = torch.autograd.grad(
        outputs=x_out,  # Derivative of x with respect to t
        inputs=t_var,
        grad_outputs=torch.ones_like(x_out),
        create_graph=True,
        retain_graph=True
    )[0]

    dzdt = torch.autograd.grad(
        outputs=z_out,  # Derivative of z with respect to t
        inputs=t_var,
        grad_outputs=torch.ones_like(z_out),
        create_graph=True,
        retain_graph=True
    )[0]

    dpxdt = torch.autograd.grad(
        outputs=px_out,  # Derivative of px with respect to t
        inputs=t_var,
        grad_outputs=torch.ones_like(px_out),
        create_graph=True,
        retain_graph=True
    )[0]

    dpzdt = torch.autograd.grad(
        outputs=pz_out,  # Derivative of pz with respect to t
        inputs=t_var,
        grad_outputs=torch.ones_like(pz_out),
        create_graph=True,
        retain_graph=True
    )[0]

    # Differential equations according to your figure
    pde1 = dxdt - dxdt_true
    pde2 = dzdt - dzdt_true
    pde3 = dpxdt - dpxdt_true
    pde4 = dpzdt - dpzdt_true

    # Calculate the loss as the sum of the squared errors of the PDEs
    residual = pde1**2 + pde2**2 + pde3**2 + pde4**2
    loss = torch.sum(weights * residual) / torch.sum(weights)
    return loss + initial_condition_loss


## Best Model

With the optimal hyperparameters determined, we can now construct and train the definitive Physics-Informed Kolmogorov-Arnold Network (PIKAN) model.

This involves several key steps:
1.  **Model Definition**: We define the KAN architecture with a specific structure: `width=[4, 12, 6, 4]`, a grid size of `12`, and a spline order `k=3`. This configuration was identified as providing a good balance between expressiveness and complexity for this problem.
2.  **Optimizer and Scheduler**: We use the `Adam` optimizer with a learning rate of `1e-2`. To facilitate stable convergence, a `ReduceLROnPlateau` learning rate scheduler is also employed, which will decrease the learning rate if the validation loss plateaus.
3.  **Loss Function**: The training will be guided by a composite loss strategy. We use a standard `MSELoss` for the data-fidelity term and our custom `physics_loss_fn` to enforce the physical constraints of ray tracing.
4.  **Architecture Instantiation**: All these components—the model, optimizer, scheduler, and loss functions—are encapsulated within our `Architecture` helper class. This class manages the training loop, applies the physics loss with its specific weight (`lambda_physics=1e-3`), and handles early stopping to prevent overfitting.
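The optimizer and scheduler are passed to `Architecture` as `functools.partial` factories, so the model parameters (and, for the scheduler, the optimizer) must be bound later; how `Architecture` does this internally is not shown and is assumed here. The sketch below binds the same settings to a hypothetical stand-in `nn.Linear` model and feeds the scheduler a validation loss that never improves, to show how `ReduceLROnPlateau` with `factor=0.1`, `patience=5`, and `min_lr=1e-6` decays the learning rate.

```python
from functools import partial

import torch.nn as nn
import torch.optim as optim

# Same settings as the training cell
make_optimizer = partial(optim.Adam, lr=1e-2)
make_scheduler = partial(optim.lr_scheduler.ReduceLROnPlateau,
                         mode='min', factor=0.1, patience=5, min_lr=1e-6)

# Hypothetical stand-in model; binding the partials to it is assumed to mirror Architecture
demo_model = nn.Linear(4, 4)
demo_opt = make_optimizer(demo_model.parameters())
demo_sched = make_scheduler(demo_opt)

lrs = []
for epoch in range(40):
    demo_opt.step()        # no gradients were computed, so parameters stay unchanged
    demo_sched.step(1.0)   # a validation loss that never improves
    lrs.append(demo_opt.param_groups[0]['lr'])

print(lrs[5], lrs[6], lrs[-1])
```

Once more than `patience=5` consecutive epochs fail to improve on the best loss, the learning rate is multiplied by 0.1; it then keeps dropping every six stagnant epochs until it reaches the `min_lr` floor.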

In [None]:
torch.manual_seed(SEED)

input_size = len(kan_features)
output_size = len(target)

model = KAN(width=[input_size, 12, 6, output_size],
            grid=12,
            grid_range=[-5, 5],
            k=3,
            auto_save=False,
            ckpt_path=str(CHECKPOINT_PATH),
            seed=SEED,
            device=DEVICE)

# Optimizer and LR scheduler factories (partials); Architecture binds them to the model
optimizer = partial(optim.Adam, lr=1e-2)
scheduler = partial(optim.lr_scheduler.ReduceLROnPlateau,
                    mode='min',
                    factor=0.1,
                    patience=5,
                    min_lr=1e-6)


# Mean squared error for the data-fidelity term
loss_fn = torch.nn.MSELoss(reduction='mean')

In [None]:
arch = Architecture(model=model,
                    loss_fn=loss_fn,
                    physics_fn=physics_loss_fn,
                    partial_optimizer=optimizer,
                    partial_scheduler=scheduler,
                    use_weighted_pi=True, 
                    lamb=0,
                    lamb_l1=0,
                    lamb_entropy=0,
                    lamb_coef=0,
                    lamb_coefdiff=0,
                    lambda_physics=1e-3,
                    singularity_avoiding=True,
                    device=DEVICE)

arch.set_loaders(train_loader, val_loader)
arch.set_early_stopping(patience=10)
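`Architecture.set_early_stopping` is project code whose internals are not shown here. A common patience-based rule it may resemble is sketched below; the `EarlyStopping` class is hypothetical and only illustrates what `patience=10` typically means: stop once the validation loss has failed to improve for 10 consecutive epochs.

```python
class EarlyStopping:
    """Hypothetical early-stopping helper (not the project's implementation)."""

    def __init__(self, patience=10, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float('inf')
        self.counter = 0

    def step(self, val_loss):
        """Record one epoch's validation loss; return True when training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.counter = 0
        else:
            self.counter += 1
        return self.counter >= self.patience


stopper = EarlyStopping(patience=10)
val_losses = [1.0, 0.8, 0.7] + [0.75] * 12  # improves for 3 epochs, then stalls
stopped_at = None
for epoch, loss in enumerate(val_losses):
    if stopper.step(loss):
        stopped_at = epoch
        break

print(stopped_at, stopper.best)
```

Implementations differ in details such as `min_delta` and whether the best weights are restored on stopping, so the exact behavior of `Architecture` should be checked against its source.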

In [None]:
n_epochs = 100
arch.train(n_epochs=n_epochs, seed=SEED)
arch.save_checkpoint(model_save_path)

In [None]:
fig = arch.plot_losses()
fig.show()

## Evaluation


After training the model, the next crucial step is to evaluate its performance on unseen data. This process allows us to assess how well the model has generalized from the training set to new, independent examples.

Our evaluation pipeline involves the following steps:
1.  **Load the Trained Model**: We rebuild the KAN with the same architecture and load the checkpoint saved after training.
2.  **Prepare the Test Dataset**: We read the held-out test set (`test_marmousi_f30.csv`), which the model did not see during training.
3.  **Make Predictions**: The trained model is used to predict the ray trajectories for the inputs in the test set.
4.  **Quantitative Analysis**: We calculate quantitative metrics to measure the discrepancy between the model's predictions and the ground truth from the numerical solver.
5.  **Qualitative Analysis**: We generate plots to visually compare the predicted ray paths against the true paths, overlaid on the Marmousi velocity model. This provides an intuitive understanding of the model's accuracy and where errors might be occurring.

Together, the metrics and plots show both how accurate the model is overall and where its errors concentrate.
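The implementation of `score` (from `utils.metrics`) is not shown in this notebook. As a rough, hypothetical sketch, a per-target summary of common regression metrics could be computed as follows; `per_target_metrics` is an illustrative name, and the metrics and row order of the real `score` may differ. The argument order (predictions first) mirrors how `score` is called below.

```python
import numpy as np
import pandas as pd


def per_target_metrics(y_pred, y_true, names):
    """Hypothetical per-target error summary: rows are metrics, columns are targets."""
    y_pred = np.asarray(y_pred, dtype=float)
    y_true = np.asarray(y_true, dtype=float)
    err = y_pred - y_true
    mse = np.mean(err ** 2, axis=0)
    # Coefficient of determination per column
    ss_tot = np.sum((y_true - y_true.mean(axis=0)) ** 2, axis=0)
    r2 = 1.0 - np.sum(err ** 2, axis=0) / ss_tot
    return pd.DataFrame(
        {'MAE': np.mean(np.abs(err), axis=0), 'RMSE': np.sqrt(mse), 'MSE': mse, 'R2': r2},
        index=names,
    ).T


y_true = np.array([[0.0, 1.0], [1.0, 2.0], [2.0, 3.0]])
y_pred = np.array([[0.0, 1.5], [1.0, 2.5], [2.0, 3.5]])  # z is off by 0.5 everywhere
metrics = per_target_metrics(y_pred, y_true, ['x', 'z'])
print(metrics)
```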


In [None]:
torch.manual_seed(SEED)

input_size = len(kan_features)
output_size = len(target)

trained_model = KAN(width=[input_size, 12, 6, output_size],
                    grid=12,
                    grid_range=[-5, 5],
                    k=3,
                    auto_save=False,
                    ckpt_path=str(CHECKPOINT_PATH),
                    seed=SEED,
                    device=DEVICE)

# Optimizer and scheduler partials, mirroring the training setup
optimizer = partial(optim.Adam, lr=1e-2)
scheduler = partial(optim.lr_scheduler.ReduceLROnPlateau,
                    mode='min',
                    factor=0.1,
                    patience=5,
                    min_lr=1e-6)


# Mean squared error loss, mirroring the training setup
loss_fn = torch.nn.MSELoss(reduction='mean')

trained_arch = Architecture(model=trained_model,
                            loss_fn=loss_fn,
                            physics_fn=None,
                            partial_optimizer=optimizer,
                            partial_scheduler=scheduler,
                            use_weighted_pi=True,
                            lamb=0,
                            lamb_l1=0,
                            lamb_entropy=0,
                            lamb_coef=0,
                            lamb_coefdiff=0,
                            lambda_physics=1e-3,
                            singularity_avoiding=True,
                            device=DEVICE
                            )

In [None]:
trained_arch.load_checkpoint(model_save_path)
df_test = pd.read_csv(DATA_PATH/'test_marmousi_f30.csv')

In [None]:
predictions = trained_arch.predict(df_test[kan_features].values)
df_pred = pd.DataFrame(predictions, columns=target)
df_pred = df_test[kan_features].join(df_pred)
df_pred.head()

In [None]:
score(df_pred[target].values, df_test[target].values)

In [None]:
# Number of desired subplots
unique_initial_conditions = (df_test[kan_features]
                             .drop(columns=['t'])
                             .drop_duplicates()
                             .sample(9, random_state=SEED)
                             )
n_initial_conditions = unique_initial_conditions.shape[0]
cols = 3
rows = (n_initial_conditions + cols - 1) // cols
fig, axs = plt.subplots(rows, cols, figsize=(22, 12))
axs = axs.flatten()  # Flatten to iterate easily

# Find the minimum and maximum values of data_gen.Vbsplines to set the color bar dynamically
vmin, vmax = np.min(data_gen.Vbsplines), np.max(data_gen.Vbsplines)

for ax, (x0, z0, theta0_p) in zip(axs[:n_initial_conditions], unique_initial_conditions.values):
    initial_condition = (
        (df_test['x0'] == x0) &
        (df_test['z0'] == z0) &
        (df_test['theta0_p'] == theta0_p)
    )

    im = ax.imshow(np.flipud(data_gen.Vbsplines), extent=[data_gen.x_range[0],
                                                          data_gen.x_range[1],
                                                          data_gen.z_range[0],
                                                          data_gen.z_range[-1]],
                   vmin=vmin, vmax=vmax)
    ax.set_title(f'{x0=:.2f}, {z0=:.2f}, {theta0_p=:.2f}')
    ax.plot(df_test.loc[initial_condition, 'x'],
            df_test.loc[initial_condition, 'z'],
            color='grey',
            linewidth=2,
            label='True Path'
            )
    ax.plot(df_pred.loc[initial_condition, 'x'],
            df_pred.loc[initial_condition, 'z'],
            color='black',
            linewidth=2,
            label='Predicted Path'
            )
    for zi in data_gen.z:
        ax.plot(data_gen.x, zi *
                np.ones_like(data_gen.x), 'kx', linewidth=2)

    ax.set_xlabel('x (km)')
    ax.set_ylabel('z (km)')
    ax.invert_yaxis()
    
    ax.set_xlim([3.5, 6.5])
    ax.set_ylim([2.5, 0])

# Add a general color bar for the entire figure
# Set the position of the color bar
cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])
fig.colorbar(im, cax=cbar_ax)

# Add a single legend outside the subplot area
legend_elements = [
    Line2D([0], [0], color='grey', lw=2, label='True Path'),
    Line2D([0], [0], color='black', lw=2, label='Predicted Path'),
    Line2D([0], [0], color='k', marker='x', lw=0, label='Z markers')
]
fig.legend(handles=legend_elements, loc='center right', bbox_to_anchor=(1.05, 0.5))

fig.suptitle('Visualization of Velocity Models and Predicted Paths', fontsize=16, y=1.02)

for ax in axs[n_initial_conditions:]:
    ax.axis('off')  # Disable axes for empty subplots

# Adjust layout to avoid overlap
plt.tight_layout(rect=[0, 0, 0.9, 1])
plt.subplots_adjust(
    left=0.05,
    top=0.92,
    bottom=0.08,
    wspace=0.0,  # horizontal space between subplots
    hspace=0.30   # vertical space between subplots
)  # Adjust space for the legend
plt.show()


In [None]:
# Create a single figure
fig, ax = plt.subplots(figsize=(10, 6))

# Plot the background map: Vbsplines
im = ax.imshow(np.flipud(data_gen.Vbsplines), 
               extent=[data_gen.x_range[0], data_gen.x_range[1],
                       data_gen.z_range[0], data_gen.z_range[-1]],
               vmin=np.min(data_gen.Vbsplines),
               vmax=np.max(data_gen.Vbsplines),
               aspect='auto')

unique_initial_conditions = (df_test[kan_features]
                             .drop(columns=['t'])
                             .drop_duplicates()
                             .sample(6, random_state=SEED)
                             )

# Plot the true (dashed grey) and predicted (solid black) paths for each sampled initial condition

for x0, z0, theta0_p in unique_initial_conditions.values:
    plot_test = (
        (df_test['x0'] == x0) &
        (df_test['z0'] == z0) &
        (df_test['theta0_p'] == theta0_p)
    )
    plot_pred = (
        (df_pred['x0'] == x0) &
        (df_pred['z0'] == z0) &
        (df_pred['theta0_p'] == theta0_p)
    )
    ax.plot(df_test[plot_test]['x'],
            df_test[plot_test]['z'],
            color='grey',
            marker='x',
            markevery=100,
            linestyle='--',
            linewidth=1.5
            )
    ax.plot(df_pred[plot_pred]['x'],
            df_pred[plot_pred]['z'],
            color='black',
            linewidth=1.5)

# Mark the Z points
for zi in data_gen.z:
    ax.plot(data_gen.x, zi * np.ones_like(data_gen.x), 'kx', linewidth=1)

# Graph adjustments
ax.set_xlabel('x (km)')
ax.set_ylabel('z (km)')
# ax.set_title('Runge-Kutta vs PIKAN predicted paths')
ax.invert_yaxis()
ax.set_xlim([3.5, 6.5])
ax.set_ylim([2.5, 0])

# Add color bar
cbar = fig.colorbar(im, ax=ax)
cbar.set_label('Velocity (km/s)')

# Add legend
legend_elements = [
    Line2D([0], [0], color='grey', lw=2, label='True Path'),
    Line2D([0], [0], color='black', lw=2, label='Predicted Path'),
    Line2D([0], [0], color='k', marker='x', lw=0, label='Z markers')
]
ax.legend(handles=legend_elements)

plt.tight_layout()
plt.show()


### Error Analysis

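First, we compute the error for each source position `(x0, z0)`, pooling all rays that start there regardless of take-off angle `theta0_p`. For each target (`x`, `z`, `px`, `pz`), we keep the MSE row of the `score` output.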

In [None]:
df_pred[initial_condition]

In [None]:
unique_initial_conditions = (df_test[kan_features]
                             .drop(columns=['t', 'theta0_p'])
                             .drop_duplicates())

df_error = pd.DataFrame()

for _, row in unique_initial_conditions.iterrows():
    x0, z0 = row
    initial_condition = (
        (df_pred['x0'] == x0) &
        (df_pred['z0'] == z0)
    )
    score_value = score(df_pred[initial_condition][target].values,
                        df_test[initial_condition][target].values)
    mse = score_value.iloc[2, :]
    aux = row.copy()
    aux['x.mse'] = mse.iloc[0]
    aux['z.mse'] = mse.iloc[1]
    aux['px.mse'] = mse.iloc[2]
    aux['pz.mse'] = mse.iloc[3]
    df_error = pd.concat([df_error, aux.to_frame().T], ignore_index=True)

df_error.head()

In [None]:
mean_differences = (
    df_error.groupby(['x0', 'z0'], as_index=False)[['x.mse', 'z.mse']]
    .mean()
)
mean_differences
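The loop above calls `score` once per source position. If only the per-target MSE is needed, a vectorized `groupby` builds the same kind of table in one pass. The sketch below uses small hypothetical frames in place of `df_test` and `df_pred`, and computes a plain MSE that is assumed, not verified, to match the MSE row of `score`.

```python
import pandas as pd

# Hypothetical stand-ins for df_test and df_pred
df_true = pd.DataFrame({'x0': [4.0, 4.0, 4.5, 4.5],
                        'z0': [1.0, 1.0, 2.0, 2.0],
                        'x': [4.0, 4.2, 4.5, 4.7],
                        'z': [1.0, 1.1, 2.0, 2.3]})
df_hat = df_true.copy()
df_hat['x'] = df_true['x'] + [0.0, 0.1, 0.0, 0.3]
df_hat['z'] = df_true['z'] + [0.0, 0.2, 0.0, 0.0]

# Squared error per row and target, then mean per source position
sq_err = (df_hat[['x', 'z']] - df_true[['x', 'z']]) ** 2
per_source_mse = (sq_err.add_suffix('.mse')
                  .join(df_true[['x0', 'z0']])
                  .groupby(['x0', 'z0'], as_index=False)
                  .mean())
print(per_source_mse)
```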


In [None]:
def plot_contour(df: pd.DataFrame, col: str) -> None:
    # Organizing data for the 2D surface plot matrix
    x = np.sort(df['x0'].unique())  # Unique sorted x0 values
    z = np.sort(df['z0'].unique())  # Unique sorted z0 values

    # Matrix of mean error values with shape (len(z), len(x)): Plotly expects
    # rows along the y-axis (z0) and columns along the x-axis (x0). Missing
    # (x0, z0) combinations stay NaN so they are not drawn as zero error.
    z_matrix = np.full((len(z), len(x)), np.nan)

    # Fill the matrix with mean error values
    for i, x_val in enumerate(x):  # Iterate over x0 values
        for j, z_val in enumerate(z):  # Iterate over z0 values
            filter_cond = (df['x0'] == x_val) & (df['z0'] == z_val)
            if filter_cond.any():  # If at least one match
                z_matrix[j, i] = df.loc[filter_cond, col].values[0]

    # Create the 2D contour plot (Surface plot in 2D)
    fig = go.Figure()

    fig.add_trace(
        go.Heatmap(
            z=np.flipud(data_gen.Vbsplines),
            x=np.linspace(0, xmax, vp.shape[1]),   # Assign the x axis range
            y=np.linspace(0, zmax, vp.shape[0]),   # Assign the z axis range
            colorscale='Viridis',  # Set the color scale
            showscale=False
        ))


    fig.add_trace(
        go.Contour(
            z=z_matrix,
            x=x,
            y=z,
            colorscale='Inferno',
            contours=dict(
                showlines=True,  # Show contour lines
                coloring='lines',  # Color only the lines
            ),
            line=dict(
                width=5,  # Thicker contour lines
            ),
            colorbar=dict(
                tickformat=".2e",  # scientific notation
                tickfont=dict(size=16),
                exponentformat="e",  # 1e+03-style exponents (optional)
            )
        ))

    # Customize the layout of the plot
    fig.update_layout(
        xaxis_title='x0 (km)',
        yaxis_title='z0 (km)',
        width=800, height=700,
        margin=dict(l=20, r=20, b=20, t=30),  # Tight margin to reduce unused space
        yaxis=dict(
            range=[3, 1],  # Reverse the Y-axis (z0 values)
            tickfont=dict(size=16),
        ),
        xaxis=dict(
            range=[4, 5],  # Restrict the x-axis to x0 between 4 and 5 km
            tickfont=dict(size=16),
        ),
        xaxis_title_font=dict(size=20),  # Increase X-axis title font size
        yaxis_title_font=dict(size=20),  # Increase Y-axis title font size
    )

    # Show the plot
    fig.show()
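`plot_contour` fills its grid with nested loops, scanning the DataFrame once per cell. Assuming each `(x0, z0)` pair appears at most once, as in `mean_differences`, `DataFrame.pivot` builds the same grid directly, already in the row-per-y, column-per-x layout that Plotly's `go.Contour` expects. The toy data below is hypothetical.

```python
import numpy as np
import pandas as pd

# Hypothetical per-source errors on an incomplete (x0, z0) grid: (5.0, 1.0) is missing
df = pd.DataFrame({'x0': [4.0, 4.5, 4.0, 4.5, 5.0],
                   'z0': [1.0, 1.0, 2.0, 2.0, 2.0],
                   'x.mse': [0.1, 0.2, 0.3, 0.4, 0.5]})

# Rows follow z0 (the plot's y-axis) and columns follow x0 (the x-axis)
grid = df.pivot(index='z0', columns='x0', values='x.mse')
x_axis = grid.columns.to_numpy()
z_axis = grid.index.to_numpy()
z_matrix = grid.to_numpy()  # missing (x0, z0) pairs become NaN
print(z_matrix)
```

The result can be passed as `go.Contour(z=z_matrix, x=x_axis, y=z_axis)`; combinations absent from the data stay `NaN` and are left blank rather than drawn as zero error.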


In [None]:
plot_contour(mean_differences, 'x.mse')

In [None]:
plot_contour(mean_differences, 'z.mse')

### Prediction time comparison

First, let's evaluate the Runge-Kutta method.

In [None]:
data_gen = DataGeneratorMarmousi(
    x_range=x_range,
    z_range=z_range
)

In [None]:
unique_initial_conditions = (df_test[kan_features]
                             .drop(columns=['t'])
                             .drop_duplicates())
unique_initial_conditions

Here, we are tracing 323 ray trajectories.

In [None]:
%%timeit -n 3 -r 7

rays = data_gen.run_batch(x0_vec=unique_initial_conditions['x0'],
                          z0_vec=unique_initial_conditions['z0'],
                          theta0_vec=unique_initial_conditions['theta0_p'],
                          vp=vp,
                          factor=30,
                          t_max=0.4
                          )

The second-order Runge-Kutta method took `29 s ± 99.4 ms per loop (mean ± std. dev. of 7 runs, 3 loops each)`. About 13 s of this is spent in the B-spline filtering of the velocity model, so a fairer figure for the ray tracing alone is roughly 16 s.

Now, we can compare with the proposed PIKAN model.
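Whether `Architecture.predict` disables autograd or moves data between devices is not shown, and both can affect the measured time. When timing a PyTorch model by hand, a common pattern is to run under `torch.no_grad()`, make one warm-up call, and, on CUDA, synchronize before reading the clock. The sketch below uses a small MLP as a hypothetical stand-in for the KAN; `time_inference` is an illustrative helper, not part of the project.

```python
import time

import torch
import torch.nn as nn


def time_inference(model, inputs, repeats=5):
    """Return the best wall-clock time (seconds) of a forward pass without autograd."""
    model.eval()
    timings = []
    with torch.no_grad():  # skip building the autograd graph
        model(inputs)  # warm-up call
        for _ in range(repeats):
            if inputs.is_cuda:
                torch.cuda.synchronize()  # wait for queued kernels before starting the clock
            start = time.perf_counter()
            model(inputs)
            if inputs.is_cuda:
                torch.cuda.synchronize()  # make sure the forward pass has finished
            timings.append(time.perf_counter() - start)
    return min(timings)


# Hypothetical stand-in network with the same layer widths as the KAN
stand_in = nn.Sequential(nn.Linear(4, 12), nn.Tanh(),
                         nn.Linear(12, 6), nn.Tanh(),
                         nn.Linear(6, 4))
batch = torch.rand(10_000, 4)
elapsed = time_inference(stand_in, batch)
print(f"{elapsed * 1e3:.3f} ms")
```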

In [None]:
%%timeit -n 3 -r 7

predictions = arch.predict(df_test[kan_features].values)

Notably, the proposed PIKAN model ran in just `399 ms ± 10.3 ms per loop (mean ± std. dev. of 7 runs, 3 loops each)`, roughly 40× faster than the adjusted 16 s Runge-Kutta time.
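For reference, the speedups follow directly from the measured means, excluding and including the roughly 13 s of B-spline filtering:

$$
\text{speedup} = \frac{t_{\text{RK2}}}{t_{\text{PIKAN}}}, \qquad
\frac{16\ \text{s}}{0.399\ \text{s}} \approx 40.1, \qquad
\frac{29\ \text{s}}{0.399\ \text{s}} \approx 72.7
$$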