# SensRay Ray Tracing & Sensitivity Kernels Demo

This notebook demonstrates:
- Ray tracing with ObsPy/TauP integration
- Computing per-cell ray path lengths
- Calculating sensitivity kernels
- Visualizing rays and kernels in cross-sections

In [None]:
import os
# Master rendering switch for the whole notebook:
#   True  -> interactive Jupyter widgets
#   False -> static/off-screen rendering (suitable for screenshot-based outputs)
INTERACTIVE = True

# Export the choice through environment variables before PyVista is imported below.
if INTERACTIVE:
    # Drop any off-screen flag left over from a previous (non-interactive) run.
    os.environ.pop('PYVISTA_OFF_SCREEN', None)
    # Ask for ipyvtklink-based interactive 3D plots (requires ipyvtklink installed).
    os.environ['PYVISTA_USE_IPYVTK'] = 'true'
else:
    os.environ['PYVISTA_OFF_SCREEN'] = 'true'
    os.environ['PYVISTA_USE_IPYVTK'] = 'false'

import numpy as np
from sensray import PlanetModel, CoordinateConverter

# Configure the PyVista Jupyter backend according to the INTERACTIVE flag.
# Every failure (PyVista not installed, backend not available) is reported as a
# warning instead of raising, so the rest of the notebook can still run.
try:
    import pyvista as pv
    if INTERACTIVE:
        # Backends are tried in order of preference and the first one that can be
        # activated wins. 'trame' is appended as a last resort so that installs
        # lacking the ipyvtklink/panel backends still get an interactive viewer
        # (NOTE(review): recent PyVista releases dropped those two - confirm
        # against the installed version).
        last_error = None
        for backend in ('ipyvtklink', 'panel', 'trame'):
            try:
                pv.set_jupyter_backend(backend)
            except Exception as err:
                # Remember why this candidate failed and move on to the next one.
                last_error = err
                continue
            print(f"Configured PyVista for interactive plotting using {backend}")
            break
        else:
            # No candidate could be activated; report the most recent failure.
            print(f"Warning: Could not set interactive PyVista backend: {last_error}")
    else:
        pv.set_jupyter_backend('static')
        print("Configured PyVista for static/off-screen plotting")
except Exception as e:
    print(f"Warning: Could not import/configure pyvista: {e}")

## 1. Setup Model and Mesh

Create a model and mesh for ray tracing experiments.

### Creating a layered tetrahedral mesh (discontinuities)

This demo uses a layered, concentric-sphere tetrahedral mesh so you can control resolution across major internal interfaces (discontinuities). The mesher accepts these main controls:

- `radii` (list of floats, ascending): interface radii in km (last entry must be the outer radius).
- `H_layers` (list of floats): target element size per layer (km). If `None`, `mesh_size_km` is used for all layers.
- `W_trans` (list of floats): half-widths for smooth size transitions at interfaces (km). If omitted, a default ~0.2*layer_thickness is used.

Notes:
- Units are kilometres throughout the API.
- For a uniform sphere just omit `radii` and set `mesh_size_km`.

Example (run this when creating a new mesh):
```python
radii = [1221.5, 3480.0, 6371.0]
H_layers = [500.0, 500.0, 300.0]
W_trans = [50.0, 100.0]  # optional
# Create mesh (use do_optimize=False for faster development)
model.create_mesh(mesh_size_km=1000.0, radii=radii, H_layers=H_layers, W_trans=W_trans, do_optimize=False)
model.mesh.populate_properties(['vp', 'vs', 'rho'])
model.mesh.save('prem_mesh.vtu')  # file extension recommended
```

In [None]:
# Load the standard 'M1' model and show its discontinuities.
model = PlanetModel.from_standard_model('M1')
print(model.get_discontinuities())

# Reuse a previously saved mesh when one exists; otherwise build, populate and save
# a new one. The same `mesh_path` is used for loading and saving so the two can
# never drift apart.
mesh_path = "M1_mesh"
try:
    model.create_mesh(from_file=mesh_path)
    print(f"Loaded existing mesh from {mesh_path}")
except FileNotFoundError:
    print("Creating new mesh...")
    # Interface radii (ascending, last entry is the outer radius) and the target
    # element size for each of the three layers, all in km.
    radii = [397.1, 1687.0, 1737.1]
    H_layers = [200, 200, 200]
    model.create_mesh(mesh_size_km=1000, radii=radii, H_layers=H_layers)
    # Attach the material properties used by the later plotting and kernel cells.
    model.mesh.populate_properties(['vp', 'vs', 'rho'])
    model.mesh.save(mesh_path)  # Cache the mesh on disk for the next run

# Reported for both the "loaded" and the "created" path, hence the neutral wording.
print(f"Mesh ready: {model.mesh.mesh.n_cells} cells")

## 2. Define Source-Receiver Geometry

Set up a realistic earthquake-station pair.

In [None]:
# Source (earthquake): on the equator at the prime meridian, 10 km deep.
source_lat = 0.0
source_lon = 0.0
source_depth = 10.0

# Receiver (seismic station): located at the surface.
receiver_lat = 70.0
receiver_lon = 45.0

# Normal of the great-circle plane through source and receiver; the cross-section
# plots further down slice the mesh along this plane.
plane_normal = CoordinateConverter.compute_gc_plane_normal(
    source_lat,
    source_lon,
    receiver_lat,
    receiver_lon,
)

print(f"Source: ({source_lat}°, {source_lon}°, {source_depth} km)")
print(f"Receiver: ({receiver_lat}°, {receiver_lon}°, 0 km)")
print(f"Great-circle plane normal: {plane_normal}")

## 3. Ray Tracing with TauP

Compute ray paths for different seismic phases.

In [None]:
# Trace the direct P phase and the PcP phase between the source and the receiver
# defined above.
ray_request = dict(
    source_depth_in_km=source_depth,
    source_latitude_in_deg=source_lat,
    source_longitude_in_deg=source_lon,
    receiver_latitude_in_deg=receiver_lat,
    receiver_longitude_in_deg=receiver_lon,
    phase_list=["P", "PcP"],
)
rays = model.taupy_model.get_ray_paths_geo(**ray_request)

# Quick-look plot of the traced paths.
# NOTE(review): newer ObsPy releases may prefer `plot_rays()` over `plot()` - confirm.
rays.plot()

# One summary line per arrival: phase name, travel time and number of path samples.
print(f"Found {len(rays)} ray paths:")
for ray_number, ray in enumerate(rays, start=1):
    print(f"  {ray_number}. {ray.phase.name}: {ray.time:.2f} s, {len(ray.path)} points")

## 4. Compute Ray Path Lengths

Calculate how much each ray travels through each mesh cell.

In [None]:
# Pick the arrivals by phase name rather than by position: the order and number of
# arrivals returned by the ray tracer are not guaranteed to be exactly [P, PcP].
P_ray = next((ray for ray in rays if ray.phase.name == "P"), None)
PcP_ray = next((ray for ray in rays if ray.phase.name == "PcP"), None)
if P_ray is None or PcP_ray is None:
    found_phases = [ray.phase.name for ray in rays]
    raise ValueError(
        f"Expected both a P and a PcP arrival, but ray tracing returned: {found_phases}"
    )

# Compute the per-cell path lengths of each ray and store them on the mesh under
# the given label.
P_lengths = model.mesh.add_ray_to_mesh(P_ray, "P_wave")
PcP_lengths = model.mesh.add_ray_to_mesh(PcP_ray, "PcP_wave")

# Total length travelled and number of cells actually crossed by each ray.
print(f"P wave: {P_lengths.sum():.1f} km total, {np.count_nonzero(P_lengths)} cells")
print(f"PcP wave: {PcP_lengths.sum():.1f} km total, {np.count_nonzero(PcP_lengths)} cells")

# List the ray-related cell arrays now present on the mesh.
ray_keys = [k for k in model.mesh.mesh.cell_data.keys() if 'ray_' in k]
print(f"Stored ray properties: {ray_keys}")

### Visualize the model and ray path lengths

In [None]:
# Great-circle cross-section coloured by the background P-wave velocity, with all
# traced ray paths drawn on top.
print("Background P-wave velocity:")
vp_section_options = dict(
    plane_normal=plane_normal,
    property_name='vp',
    show_rays=rays,
)
plotter1 = model.mesh.plot_cross_section(**vp_section_options)

# Fixed viewpoint for the rendered scene.
plotter1.camera.position = (8000, 6000, 10000)

# INTERACTIVE decides between an interactive widget and a static render.
# For an image file in the non-interactive case, call
# plotter1.screenshot('vp_background.png') instead.
plotter1.show(interactive=INTERACTIVE)


In [None]:
# Great-circle cross-section coloured by the per-cell P-wave path length, with the
# P ray drawn on top.
print("P-wave path lengths through cells:")
path_length_section_options = dict(
    plane_normal=plane_normal,
    # Presumably the array stored by add_ray_to_mesh(P_ray, "P_wave") - verify the name.
    property_name='ray_P_wave_P_lengths',
    show_rays=[P_ray],
)
plotter2 = model.mesh.plot_cross_section(**path_length_section_options)

# Fixed viewpoint for the rendered scene.
plotter2.camera.position = (8000, 6000, 10000)

# INTERACTIVE decides between an interactive widget and a static render.
plotter2.show(interactive=INTERACTIVE)

## 5. Sensitivity Kernels

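Compute travel-time sensitivity kernels, $K = -L / v^2$ for each cell, where $L$ is the ray path length inside the cell (km) and $v$ is the cell's seismic velocity (km/s), giving $K$ in s²/km.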

In [None]:
# Compute travel-time sensitivity kernels for the P and PcP rays.
# Both phases are compressional along their entire path, so both kernels are taken
# with respect to the P-wave velocity 'vp' (the PcP kernel was previously computed
# against 'vs' and attached as 'K_PcP_vs').
# `epsilon` presumably guards against division by a vanishing velocity - confirm.
P_kernel = model.mesh.compute_sensitivity_kernel(
    P_ray, property_name='vp', attach_name='K_P_vp', epsilon=1e-6
)
PcP_kernel = model.mesh.compute_sensitivity_kernel(
    PcP_ray, property_name='vp', attach_name='K_PcP_vp', epsilon=1e-6
)

# K = -L / v² with L in km and v in km/s gives units of s²/km.
print(f"P kernel range: {P_kernel.min():.6f} to {P_kernel.max():.6f} s²/km")
print(f"PcP kernel range: {PcP_kernel.min():.6f} to {PcP_kernel.max():.6f} s²/km")
print(f"Non-zero P kernel cells: {np.count_nonzero(P_kernel)}")
print(f"Non-zero PcP kernel cells: {np.count_nonzero(PcP_kernel)}")

In [None]:
# Great-circle cross-section coloured by the P-wave kernel stored as 'K_P_vp',
# with the P ray drawn on top.
p_kernel_section_options = dict(
    property_name='K_P_vp',
    show_rays=[P_ray],
    cmap='magma',
)
plotter3 = model.mesh.plot_cross_section(plane_normal, **p_kernel_section_options)

# INTERACTIVE decides between an interactive widget and a static render.
plotter3.show(interactive=INTERACTIVE)

## 6. Multiple Ray Kernels

Combine kernels from multiple rays for enhanced sensitivity.

In [None]:
# Build one combined 'vp' kernel from every traced ray (all arrivals in `rays`, not
# only PcP), requested with accumulate='sum'. Skipped when fewer than two rays exist.
if len(rays) >= 2:
    combined_kernel = model.mesh.compute_sensitivity_kernel(
        rays,
        property_name='vp',
        attach_name='K_combined_vp',
        accumulate='sum',
    )
    # Summarise the value range and the number of cells touched by at least one ray.
    combined_min = combined_kernel.min()
    combined_max = combined_kernel.max()
    print(f"Combined kernel range: {combined_min:.6f} to {combined_max:.6f}")
    combined_nonzero = np.count_nonzero(combined_kernel)
    print(f"Combined kernel non-zero cells: {combined_nonzero}")

In [None]:
# Great-circle cross-section coloured by the combined kernel stored as
# 'K_combined_vp', with both the P and PcP rays drawn on top.
combined_section_options = dict(
    property_name='K_combined_vp',
    show_rays=[P_ray, PcP_ray],
    cmap='magma',
)
plotter4 = model.mesh.plot_cross_section(plane_normal, **combined_section_options)

# INTERACTIVE decides between an interactive widget and a static render.
plotter4.show(interactive=INTERACTIVE)

## 7. Export Results

Save mesh with all computed properties for further analysis.

In [None]:
# Write the mesh, including the ray and kernel arrays attached above, to disk.
# NOTE(review): the base name says "prem" although this notebook loads the 'M1'
# model - consider renaming for clarity.
model.mesh.save('prem_mesh_with_rays_kernels')

# Report which cell properties the mesh currently carries.
info = model.mesh.list_properties(show_stats=False)
cell_props = info['cell_data']
print(f"Saved {len(cell_props)} properties to VTU file:")
for prop in cell_props.keys():
    print(f"  - {prop}")

# File names below are what save() is expected to produce - verify against the
# actual output directory.
print(
    "\nFiles created:",
    "  - prem_mesh_with_rays_kernels.vtu (mesh + all data)",
    "  - prem_mesh_with_rays_kernels_metadata.json (property list)",
    sep="\n",
)