# SensRay MeshEarthModel Demo

This demo showcases the core features of `MeshEarthModel`:
- Build a tetrahedral sphere mesh
- Map properties from a 1D Earth model (e.g., PREM) onto mesh cells
- Visualize constant-per-cell values on a spherical shell and on great-circle slices
- Compute per-cell path lengths for a ray traced through the 1D model and derive sensitivity kernels
- Overlay points (e.g., source and receiver) on plots

```python
import numpy as np
from sensray.mesh.earth_model import MeshEarthModel
from sensray.core.ray_paths import RayPathTracer
```

## 1) Create a tetrahedral sphere mesh
We'll generate a reasonably coarse sphere to keep things fast in a notebook.

```python
# Adjust mesh_size_km to trade accuracy vs. speed (smaller = finer mesh, slower)
mesh_model = MeshEarthModel.from_pygmsh_sphere(mesh_size_km=500)
mesh_model.radius_km
```

```
6371.0
```
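Cost grows quickly as `mesh_size_km` shrinks. As a rough illustration, assume the elements are regular tetrahedra with edge length equal to `mesh_size_km` (Gmsh produces irregular elements, so real counts will differ) and divide the sphere's volume by the volume of one such tetrahedron, h^3 / (6 * sqrt(2)). The helper below is hypothetical and not part of sensray:

```python
import math


def estimate_tet_count(radius_km: float, mesh_size_km: float) -> int:
    """Order-of-magnitude cell count for a ball filled with regular tetrahedra.

    Hypothetical helper for illustration only; not part of sensray.
    """
    sphere_volume = 4.0 / 3.0 * math.pi * radius_km ** 3
    # Volume of a regular tetrahedron with edge length h is h^3 / (6 * sqrt(2))
    tet_volume = mesh_size_km ** 3 / (6.0 * math.sqrt(2.0))
    return round(sphere_volume / tet_volume)


coarse = estimate_tet_count(6371.0, 500.0)
fine = estimate_tet_count(6371.0, 250.0)
print(coarse, fine, fine / coarse)
```

Halving the element size multiplies the cell count by about 2^3 = 8, so memory use and most per-cell operations grow roughly eightfold as well.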

## 2) Map 1D Earth model properties onto mesh cells
We'll add S-wave and P-wave velocities (`vs`, `vp`) from the selected 1D model.

```python
model_name = 'prem'  # TauP 1D model name
mesh_model.add_scalars_from_1d_model(model_name, properties=('vs', 'vp'), where='cell')
# Quick sanity check: list the available cell_data keys
list(mesh_model.mesh.cell_data.keys())
```

```
['vs', 'vp']
```
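Conceptually, mapping a 1D model onto cells means evaluating a depth profile at one representative point per cell. The sketch below assumes that point is the cell centroid and uses linear interpolation in depth; `MeshEarthModel` may sample differently, and a real profile such as PREM has discontinuities that a cell straddling them will blur. The helper name and the profile values are made up for illustration:

```python
import numpy as np

EARTH_RADIUS_KM = 6371.0

# Hypothetical toy profile: depth (km) -> vs (km/s). These are NOT PREM values.
TOY_DEPTH_KM = np.array([0.0, 100.0, 670.0, 2891.0])
TOY_VS_KM_S = np.array([3.2, 4.4, 5.6, 7.2])


def sample_profile_at_centroids(points, cells, depth_km, values,
                                radius_km=EARTH_RADIUS_KM):
    """Assign each tetrahedron the profile value at its centroid depth.

    points: (n_points, 3) vertex coordinates in km
    cells:  (n_cells, 4) vertex indices of each tetrahedron
    """
    centroids = points[cells].mean(axis=1)                 # (n_cells, 3)
    depth = radius_km - np.linalg.norm(centroids, axis=1)  # (n_cells,)
    # Linear interpolation in depth; np.interp clamps outside the table range
    return np.interp(depth, depth_km, values)


# Two small tetrahedra centred at 50 km and 670 km depth along the x-axis
offsets = np.array([[10.0, 0, 0], [-10.0, 0, 0], [0, 10.0, 10.0], [0, -10.0, -10.0]])
points = np.vstack([offsets + [6321.0, 0, 0], offsets + [5701.0, 0, 0]])
cells = np.array([[0, 1, 2, 3], [4, 5, 6, 7]])
vs_cells = sample_profile_at_centroids(points, cells, TOY_DEPTH_KM, TOY_VS_KM_S)
print(vs_cells)
```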

## 3) Visualize a spherical shell with constant-per-cell coloring
The shell's geometry and wireframe follow the underlying mesh, and each shell facet takes the constant value of its parent cell, the tetrahedron it lies in.

```python
mesh_model.plot_sphere(radius_km=5000, scalar_name='vs', wireframe=True).show()
```
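"Constant per cell" means a point on the shell or slice simply takes the value of the tetrahedron that contains it, with no blending between neighbours. A brute-force sketch of that lookup uses barycentric coordinates: a point is inside a tetrahedron when all four coordinates are non-negative. The function names are hypothetical, and a practical implementation would use a spatial index rather than scanning every cell:

```python
import numpy as np


def barycentric_coords(p, tet):
    """Barycentric coordinates of point p with respect to a (4, 3) tetrahedron."""
    v0 = tet[0]
    edges = (tet[1:] - v0).T               # 3x3 matrix, columns are edge vectors
    lam = np.linalg.solve(edges, p - v0)   # weights for vertices 1..3
    return np.concatenate(([1.0 - lam.sum()], lam))


def cell_value_at(p, points, cells, cell_values, tol=1e-12):
    """Return the value of the cell containing p (no interpolation), or NaN."""
    p = np.asarray(p, dtype=float)
    for cell_id, cell in enumerate(cells):
        if np.all(barycentric_coords(p, points[cell]) >= -tol):
            return cell_values[cell_id]
    return np.nan


# Two tetrahedra sharing the face (1,0,0)-(0,1,0)-(0,0,1), each with one value
points = np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 1, 1]], dtype=float)
cells = np.array([[0, 1, 2, 3], [1, 2, 3, 4]])
cell_values = np.array([4.5, 6.0])

near_origin = cell_value_at([0.1, 0.1, 0.1], points, cells, cell_values)
far_side = cell_value_at([0.6, 0.6, 0.6], points, cells, cell_values)
outside = cell_value_at([2.0, 2.0, 2.0], points, cells, cell_values)
print(near_origin, far_side, outside)
```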

## 4) Great-circle slice with constant-per-cell coloring
Polygons on the slice inherit their parent cell values for crisp, non-interpolated coloring.

```python
# Source (depth in km) and receiver; these are reused for ray tracing in section 5
source_lat, source_lon, source_depth = 0.0, 0.0, 10.0
receiver_lat, receiver_lon = 0.0, 80.0
p = mesh_model.plot_slice(
    source_lat=source_lat,
    source_lon=source_lon,
    receiver_lat=receiver_lat,
    receiver_lon=receiver_lon,
    scalar_name='vs',
    cmap='RdBu',
    wireframe=True,
)
p.show()
```
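A great-circle slice is the plane through Earth's centre, the source, and the receiver. Assuming `plot_slice` defines its plane this way (its internals are not shown here), the plane normal is the normalized cross product of the two surface unit vectors. The helpers below are hypothetical:

```python
import numpy as np


def latlon_to_unit(lat_deg, lon_deg):
    """Unit vector for a geocentric latitude/longitude in degrees."""
    lat, lon = np.radians(lat_deg), np.radians(lon_deg)
    return np.array([np.cos(lat) * np.cos(lon),
                     np.cos(lat) * np.sin(lon),
                     np.sin(lat)])


def great_circle_plane_normal(src_lat, src_lon, rcv_lat, rcv_lon):
    """Unit normal of the plane through Earth's centre, source, and receiver."""
    n = np.cross(latlon_to_unit(src_lat, src_lon), latlon_to_unit(rcv_lat, rcv_lon))
    norm = np.linalg.norm(n)
    if norm < 1e-12:
        # Coincident or antipodal points do not define a unique great circle
        raise ValueError("source and receiver are coincident or antipodal")
    return n / norm


normal = great_circle_plane_normal(0.0, 0.0, 0.0, 80.0)
print(normal)
```

For the source at (0, 0) and receiver at (0, 80) used here, the normal is the z-axis, so the slice is the equatorial plane.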

## 5) Trace a ray, compute path lengths, and derive a sensitivity kernel
We'll trace a P-wave ray from the source to the receiver through the 1D model with `RayPathTracer`, then compute the length of the ray inside each mesh cell and a `K_vs` kernel from those lengths.
Note: because the ray is computed in the 1D reference model, it is curved rather than a straight chord.
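Per-cell path length is the sum of the ray segments that fall inside each cell. A simplified sketch, assuming the ray is sampled finely enough that each segment lies within a single cell and that a cell index per segment is already known (for example from a point locator at segment midpoints), uses `np.bincount` with segment lengths as weights. `compute_ray_cell_path_lengths` may instead clip segments exactly at cell faces; the helper name here is hypothetical:

```python
import numpy as np


def accumulate_path_lengths(ray_xyz, segment_cell_ids, n_cells):
    """Sum segment lengths per cell.

    ray_xyz:          (3, N) polyline, rows x, y, z (the layout used for ray_xyz in this section)
    segment_cell_ids: (N - 1,) cell index for each segment, -1 if outside the mesh
    """
    ray_xyz = np.asarray(ray_xyz, dtype=float)
    segment_cell_ids = np.asarray(segment_cell_ids)
    seg_len = np.linalg.norm(np.diff(ray_xyz, axis=1), axis=0)  # (N - 1,)
    inside = segment_cell_ids >= 0
    return np.bincount(segment_cell_ids[inside], weights=seg_len[inside],
                       minlength=n_cells)


# A straight 3 km ray along x, split into three 1 km segments
ray = np.array([[0.0, 1.0, 2.0, 3.0], [0.0] * 4, [0.0] * 4])
lengths = accumulate_path_lengths(ray, [0, 0, 2], n_cells=4)
print(lengths)  # [2. 0. 1. 0.]
```

Summing the result gives the total path length inside the mesh, the quantity reported by `float(L.sum())` in the summary.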

```python
# Trace a P-wave ray through the 1D model from the source to the receiver
tracer = RayPathTracer(model_name=model_name)
rays = tracer.get_ray_paths(source_lat=source_lat, source_lon=source_lon, source_depth=source_depth,
                            receiver_lat=receiver_lat, receiver_lon=receiver_lon, phases=['P'])
rays_coords = tracer.extract_ray_coordinates(rays[0])

# Stack into a 3 x N ndarray (rows: x, y, z)
ray_xyz = np.vstack((
    rays_coords['P']['3d_x_cartesian'],
    rays_coords['P']['3d_y_cartesian'],
    rays_coords['P']['3d_z_cartesian'],
)).astype(float)

# Quick check: inspect the ray coordinates (km)
print(ray_xyz)

# Compute per-cell path lengths and attach them as the 'ray_lengths' cell scalar
L = mesh_model.compute_ray_cell_path_lengths(ray_xyz, attach_name='ray_lengths', tol=1e-6)

# Compute sensitivity kernel K_vs = -L / (vs^2 + epsilon) and attach it as 'K_vs'
K_vs = mesh_model.compute_sensitivity_kernel(ray_xyz, property_name='vs', attach_name='K_vs', epsilon=1e-6)

# Summary: total path length (km), number of cells crossed, max |K_vs|
float(L.sum()), int(np.count_nonzero(L)), float(np.nanmax(np.abs(K_vs)))
```

```
[[6.36100000e+03 6.35599983e+03 6.35486014e+03 ... 1.10652754e+03
  1.10642004e+03 1.10631254e+03]
 [0.00000000e+00 1.46634938e+00 1.86472440e+00 ... 6.27226834e+03
  6.27323927e+03 6.27421019e+03]
 [6.35100000e+03 6.34100000e+03 6.33872083e+03 ... 6.36725000e+03
  6.36912500e+03 6.37100000e+03]]

(6952.154093344296, 45, 21.091778870383205)
```
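The kernel formula in the code comment follows from the travel time along the ray, t = sum over cells of L_i / v_i, whose derivative with respect to the velocity in cell i is dt/dv_i = -L_i / v_i^2; `epsilon` only guards against division by zero. With L in km and v in km/s, K is in s^2/km, and a small velocity change gives a travel-time change dt ≈ sum over cells of K_i * dv_i in seconds. The sketch below (hypothetical helper names, made-up lengths and velocities) checks the analytic formula against a central finite difference:

```python
import numpy as np


def travel_time(lengths_km, v_km_s):
    """Ray travel time (s) as the sum of per-cell length / velocity."""
    return float(np.sum(lengths_km / v_km_s))


def kernel(lengths_km, v_km_s, epsilon=1e-6):
    """Analytic sensitivity dt/dv = -L / (v^2 + epsilon), per cell."""
    return -lengths_km / (v_km_s ** 2 + epsilon)


def finite_difference_kernel(lengths_km, v_km_s, h=1e-6):
    """Central-difference estimate of dt/dv for each cell."""
    fd = np.zeros(v_km_s.size)
    for i in range(v_km_s.size):
        v_plus, v_minus = v_km_s.copy(), v_km_s.copy()
        v_plus[i] += h
        v_minus[i] -= h
        fd[i] = (travel_time(lengths_km, v_plus) - travel_time(lengths_km, v_minus)) / (2 * h)
    return fd


lengths = np.array([100.0, 0.0, 50.0])  # km; the middle cell is not crossed
vs = np.array([3.5, 4.0, 6.0])          # km/s
K = kernel(lengths, vs)
K_fd = finite_difference_kernel(lengths, vs)
print(K, K_fd)
```

Cells the ray does not cross have zero path length and therefore zero sensitivity, which is why only a few dozen cells carry a nonzero kernel.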

## 6) Visualize path lengths on a slice and the kernel on a shell
Because path lengths (`ray_lengths`) and kernels (`K_vs`) are stored as cell scalars, the same constant-per-cell rendering applies.

```python
# Per-cell path lengths on the great-circle slice
p = mesh_model.plot_slice(
    source_lat=source_lat,
    source_lon=source_lon,
    receiver_lat=receiver_lat,
    receiver_lon=receiver_lon,
    scalar_name='ray_lengths',
    cmap='viridis',
    wireframe=True,
)
p.show()

# Kernel on a spherical shell at 5000 km radius
mesh_model.plot_sphere(radius_km=5000, scalar_name='K_vs', wireframe=True).show()
```

## 7) Overlay source and receiver points on a slice
The `add_points` helper takes the plotter and a dict of `lat`, `lon`, and `depth_km` lists, with one entry per point.

```python
# Mark the source at its depth (source_depth, set in section 4) and the receiver at the surface
receiver_depth_km = 0.0
pl = mesh_model.plot_slice(
    source_lat=source_lat,
    source_lon=source_lon,
    receiver_lat=receiver_lat,
    receiver_lon=receiver_lon,
    scalar_name='vs',
    cmap='RdBu',
    wireframe=True,
)
mesh_model.add_points(pl, {
    'lat': [source_lat, receiver_lat],
    'lon': [source_lon, receiver_lon],
    'depth_km': [source_depth, receiver_depth_km],
})
pl.show()
```
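Placing points from `lat`, `lon`, and `depth_km` requires converting geographic coordinates to the mesh's Cartesian frame. Assuming a spherical Earth of radius 6371 km (the `radius_km` reported in section 1) and geocentric latitude, the conversion can be sketched as below; `add_points` may implement it differently, and the helper name is hypothetical. For an equatorial source at 10 km depth it gives x = 6361 km, and for the receiver at 80 degrees longitude x ≈ 1106.3 km and y ≈ 6274.2 km, matching the first and last x and y values of the ray coordinates printed in section 5:

```python
import numpy as np

EARTH_RADIUS_KM = 6371.0


def geographic_to_cartesian(lat_deg, lon_deg, depth_km, radius_km=EARTH_RADIUS_KM):
    """Convert geocentric latitude/longitude (degrees) and depth (km) to x, y, z in km.

    Assumes a spherical Earth; x points to (0 N, 0 E) and z to the North Pole.
    """
    lat = np.radians(np.asarray(lat_deg, dtype=float))
    lon = np.radians(np.asarray(lon_deg, dtype=float))
    r = radius_km - np.asarray(depth_km, dtype=float)
    return np.stack([
        r * np.cos(lat) * np.cos(lon),
        r * np.cos(lat) * np.sin(lon),
        r * np.sin(lat),
    ], axis=-1)


# Source at 10 km depth (as in section 4) and a receiver at the surface
xyz = geographic_to_cartesian([0.0, 0.0], [0.0, 80.0], [10.0, 0.0])
print(xyz)
```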