# Example notebook using the Pykonal solver for the 3D Eikonal equation

### Import modules

In [None]:
%matplotlib ipympl
import matplotlib.gridspec
import matplotlib.pyplot as plt
import numpy as np
import pykonal

### Define function to plot results

In [None]:
def plot(solver, ix=None, iy=None, iz=None, attr='uu', rays=None, cbar_label='Travel-time [s]'):
    """Plot three orthogonal slices through a 3D field stored on *solver*.

    The figure uses a 2x2 layout: the X-Z slice (top left), the Y-Z slice
    (top right), the X-Y slice (bottom left) and a horizontal colorbar
    (bottom right). All three slices share one color scale. White lines
    mark where some of the other slices intersect each panel.

    Parameters
    ----------
    solver
        Object exposing ``pgrid`` (with ``min_coords``, ``node_intervals``
        and ``npts``) and the 3D array named by *attr*; in this notebook a
        ``pykonal.EikonalSolver``.
    ix, iy, iz : int, optional
        Node indices of the Y-Z, X-Z and X-Y slices respectively. Each one
        defaults to the middle of its axis.
    attr : str, optional
        Name of the solver attribute to plot (default ``'uu'``).
    rays : iterable, optional
        Ray paths to overlay as dashed black lines. Each ray is indexed as
        ``ray[:, 0]``, ``ray[:, 1]``, ``ray[:, 2]`` for x, y and z.
    cbar_label : str, optional
        Label written next to the colorbar.
    """
    # pcolormesh needs the corners of every cell, i.e. one more coordinate
    # than data values along each axis, so build a grid with npts + 1 nodes.
    grid                = pykonal.GridND(ndim=3)
    grid.min_coords     = solver.pgrid.min_coords
    grid.node_intervals = solver.pgrid.node_intervals
    grid.npts           = solver.pgrid.npts + 1

    # Default every slice index to the middle of its axis.
    if ix is None:
        ix = int(grid.npts[0] / 2) - 1
    if iy is None:
        iy = int(grid.npts[1] / 2) - 1
    if iz is None:
        iz = int(grid.npts[2] / 2) - 1

    data = getattr(solver, attr)
    data_xy = data[:, :, iz]
    data_xz = data[:, iy, :]
    data_yz = data[ix, :, :]

    # Derive one shared color range from the values that are actually shown.
    shown_values = np.concatenate([data_xy.flatten(), data_xz.flatten(), data_yz.flatten()])
    vmin = np.min(shown_values)
    vmax = np.max(shown_values)

    # Size the panels in proportion to the physical extent of each axis.
    dx, dy, dz = grid.max_coords - grid.min_coords
    dmax = np.max([dx, dy, dz])
    aspect = (dx + dy) / (dz + dy)
    gs = matplotlib.gridspec.GridSpec(
        2, 2,
        width_ratios=[dx/dmax, dy/dmax],
        height_ratios=[dz/dmax, dy/dmax]
    )
    fig = plt.figure(figsize=(aspect*8+0.3, aspect*8))
    # Attach the axes to ``fig`` explicitly instead of relying on pyplot's
    # notion of the "current" figure.
    ax_xz = fig.add_subplot(gs[0], aspect=1)
    ax_yz = fig.add_subplot(gs[1], aspect=1)
    ax_xy = fig.add_subplot(gs[2], aspect=1)

    # The fourth cell hosts the colorbar in its upper, thinner sub-row.
    cbar_gs = matplotlib.gridspec.GridSpecFromSubplotSpec(2, 1, subplot_spec=gs[3], height_ratios=[1, 10])
    cax = fig.add_subplot(cbar_gs[0])

    kwargs = dict(
        cmap=plt.get_cmap('jet_r'),
        vmin=vmin,
        vmax=vmax
    )

    def draw_rays(ax, horizontal, vertical):
        """Overlay every ray on *ax* using the given coordinate columns."""
        if rays is None:
            return
        for ray in rays:
            ax.plot(ray[:, horizontal], ray[:, vertical], 'k--')

    # --- X-Z slice (top left) ---
    # Corner coordinates are shifted by half a node interval so that every
    # colored cell is centered on its node.
    ax_xz.pcolormesh(
        grid[:, iy, :, 0] - grid.node_intervals[0] / 2,
        grid[:, iy, :, 2] - grid.node_intervals[2] / 2,
        data_xz,
        **kwargs
    )
    ax_xz.axhline(grid[0, 0, iz, 2], color='w')  # z of the X-Y slice
    ax_xz.axvline(grid[ix, 0, 0, 0], color='w')  # x of the Y-Z slice
    draw_rays(ax_xz, 0, 2)
    ax_xz.xaxis.tick_top()
    ax_xz.xaxis.set_label_position('top')
    ax_xz.set_xlabel('X')
    ax_xz.set_ylabel('Z')

    # --- Y-Z slice (top right) ---
    ax_yz.pcolormesh(
        grid[ix, :, :, 1] - grid.node_intervals[1] / 2,
        grid[ix, :, :, 2] - grid.node_intervals[2] / 2,
        data_yz,
        **kwargs
    )
    ax_yz.axvline(grid[0, iy, 0, 1], color='w')  # y of the X-Z slice
    draw_rays(ax_yz, 1, 2)
    ax_yz.xaxis.tick_top()
    ax_yz.xaxis.set_label_position('top')
    ax_yz.yaxis.tick_right()
    ax_yz.yaxis.set_label_position('right')
    ax_yz.set_xlabel('Y')
    ax_yz.set_ylabel('Z')

    # --- X-Y slice (bottom left) ---
    # Keep a handle on this mesh: the colorbar is built from it.
    xy_mesh = ax_xy.pcolormesh(
        grid[:, :, iz, 0] - grid.node_intervals[0] / 2,
        grid[:, :, iz, 1] - grid.node_intervals[1] / 2,
        data_xy,
        **kwargs
    )
    ax_xy.axhline(grid[0, iy, 0, 1], color='w')  # y of the X-Z slice
    draw_rays(ax_xy, 0, 1)
    ax_xy.invert_yaxis()
    ax_xy.set_xlabel('X')
    ax_xy.set_ylabel('Y')

    cbar = fig.colorbar(xy_mesh, cax=cax, orientation='horizontal')
    cbar.set_label(cbar_label)

    fig.tight_layout()

### Define function to instantiate EikonalSolver with uniform velocity model

In [None]:
def init_solver():
    """Return a new ``pykonal.EikonalSolver`` with a uniform velocity model.

    The velocity grid starts at the origin, has a node interval of 1 along
    every axis and 11 nodes per axis; the velocity is 1 at every node.
    """
    uniform_solver = pykonal.EikonalSolver()
    vgrid = uniform_solver.vgrid

    # Minimum coordinates of the velocity grid: xmin, ymin, zmin.
    vgrid.min_coords = (0, 0, 0)
    # Spacing between neighboring velocity-grid nodes: dx, dy, dz.
    vgrid.node_intervals = (1, 1, 1)
    # Number of velocity-grid nodes along each axis: nx, ny, nz.
    vgrid.npts = (11, 11, 11)

    # The velocity at each grid node is stored in a numpy.ndarray whose shape
    # matches the number of nodes per axis; here it is 1 everywhere.
    uniform_solver.vv = np.ones(vgrid.npts)
    return uniform_solver

## Simplest case

In [None]:
# Build a solver with the uniform velocity model.
solver = init_solver()

# Put a single source at node (5, 5, 5), the center of the computational domain.
src_idx = (5, 5, 5)
solver.is_far[src_idx] = False  # The source node must not be flagged as far
solver.uu[src_idx] = 0          # The travel time at the source itself is zero
solver.close.push(*src_idx)     # The close heap starts out holding the source index

# Solve the Eikonal equation.
solver.solve()

# Show the resulting travel-time field.
plot(solver)

## The propagation grid does not have to coincide with the velocity grid
You may want to make it denser for more accurate solutions. Just make sure the boundaries of the propagation grid fall within the boundaries of the velocity grid; otherwise, the EikonalSolver will raise an *OutOfBoundsError* because it needs to interpolate the velocity at each node of the propagation grid.

In [None]:
# Start again from the uniform velocity model.
solver = init_solver()

# Refine the propagation grid by a factor of 2: halve the node interval...
solver.pgrid.node_intervals = solver.vgrid.node_intervals / 2
# ...and use 2 * (n - 1) + 1 nodes per axis, so the refined grid does not go
# beyond the boundaries of the velocity grid.
solver.pgrid.npts = 2 * (solver.vgrid.npts - 1) + 1

# Put a source at node (10, 10, 10), the center of the refined grid.
src_idx = (10, 10, 10)
solver.is_far[src_idx] = False  # The source node must not be flagged as far
solver.uu[src_idx] = 0          # The travel time at the source itself is zero
solver.close.push(*src_idx)     # The close heap starts out holding the source index

# Solve the Eikonal equation again.
solver.solve()

# And plot the results.
plot(solver)

## There is no real limit to how dense you can make the propagation grid...
...if you're willing to be patient.

In [None]:
solver = init_solver()

# Refine the propagation grid by a factor of 10: ten times smaller node
# interval and 10 * (n - 1) + 1 nodes per axis.
solver.pgrid.node_intervals = solver.vgrid.node_intervals / 10
solver.pgrid.npts = 10 * (solver.vgrid.npts - 1) + 1

# Put a source at node (50, 50, 50), the center of the refined grid.
src_idx = (50, 50, 50)
solver.is_far[src_idx] = False  # The source node must not be flagged as far
solver.uu[src_idx] = 0          # The travel time at the source itself is zero
solver.close.push(*src_idx)     # The close heap starts out holding the source index

solver.solve()
plot(solver)

## And you can move the source if you would like

In [None]:
solver = init_solver()

# Same 10x refined propagation grid as before.
solver.pgrid.node_intervals = solver.vgrid.node_intervals / 10
solver.pgrid.npts = 10 * (solver.vgrid.npts - 1) + 1

# This time the source sits away from the center of the grid.
src_idx = (25, 50, 75)
solver.is_far[src_idx] = False  # The source node must not be flagged as far
solver.uu[src_idx] = 0          # The travel time at the source itself is zero
solver.close.push(*src_idx)     # The close heap starts out holding the source index

solver.solve()
plot(solver)

## Or add multiple sources

In [None]:
solver = init_solver()

# Same 10x refined propagation grid as before.
solver.pgrid.node_intervals = solver.vgrid.node_intervals / 10
solver.pgrid.npts = 10 * (solver.vgrid.npts - 1) + 1

# Register two sources, both with a travel time of zero.
for src_idx in ((25, 50, 75), (75, 50, 25)):
    solver.is_far[src_idx] = False  # The source node must not be flagged as far
    solver.uu[src_idx] = 0          # The travel time at the source itself is zero
    solver.close.push(*src_idx)     # Add the source index to the close heap

solver.solve()
plot(solver)

## And sources at non-zero times

In [None]:
solver = init_solver()

# Same 10x refined propagation grid as before.
solver.pgrid.node_intervals = solver.vgrid.node_intervals / 10
solver.pgrid.npts           = solver.vgrid.npts * 10 - 9

# Each source is paired with its own initial time t0: the first starts at
# 0 and the second at 2.5.
for src_idx, t0 in (((25, 50, 75), 0), ((75, 50, 25), 2.5)):
    solver.uu[src_idx]     = t0    # Set the traveltime at the source location to t0
    solver.is_far[src_idx] = False # Set the is_far flag to False for the source node
    solver.close.push(*src_idx)    # Push the source index onto the close heap

# Solve and plot once; repeating both calls only produced a duplicate figure.
solver.solve()
plot(solver)

## You can use more interesting velocity models

In [None]:
solver = init_solver()

# Replace the uniform model with a velocity that increases linearly from 1 to
# 5 along the Y axis: every node sharing a Y index gets the same velocity.
vy = np.linspace(1, 5, solver.vgrid.npts[1])
for iy, velocity in enumerate(vy):
    solver.vv[:, iy] = velocity

# Same 10x refined propagation grid as before.
solver.pgrid.node_intervals = solver.vgrid.node_intervals / 10
solver.pgrid.npts           = solver.vgrid.npts * 10 - 9

src_idx = (25, 50, 75)
solver.uu[src_idx]     = 0     # Set the traveltime at the source location to zero
solver.is_far[src_idx] = False # Set the is_far flag to False for the source node
solver.close.push(*src_idx)    # Push the source index onto the close heap

solver.solve()

# This will plot the velocity model
# NOTE(review): 'vvp' is presumably the velocity on the propagation grid -- confirm.
plot(solver, attr='vvp', cbar_label='Velocity [km/s]')
# And this will plot the travel-time field
plot(solver)

## And you can trace rays too

In [None]:
# Trace a ray for the point (9.5, 0, 9.5) with the solver from the previous cell.
# NOTE(review): presumably the argument is the ray's end point in grid
# coordinates -- confirm against the pykonal documentation.
ray = solver.trace_ray((9.5, 0, 9.5))

# Overlay the traced ray on the travel-time slices.
plot(solver, rays=[ray])

## Use the pykonal.LinearInterpolator3D class to interpolate the travel-time field at arbitrary locations

In [None]:
# Build an interpolator from the propagation grid and the field solver.uu
# stored on it.
ui = pykonal.LinearInterpolator3D(solver.pgrid, solver.uu)
# Evaluate the interpolator at one point; as the last expression of the cell,
# its value is what the notebook displays.
# NOTE(review): presumably the tuple is (x, y, z) in grid coordinates -- confirm.
ui((4.25, 3.1, 2.98))