# Full Waveform Inversion (FWI)

## 1. Introduction

Full Waveform Inversion (FWI) simulates a geophysical survey and estimates the model parameters (e.g., seismic velocity) that best explain the observed waveforms. It does so by minimizing a measure of error between simulated and observed data (the misfit). This process is known as inversion.
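
A common choice for this measure of error is the least-squares misfit between simulated and observed receiver traces, `J = 0.5 * sum_r ∫ (p_sim - p_obs)^2 dt`. The minimal NumPy sketch below approximates the time integral with the rectangle rule. The function name `l2_misfit` and the (time steps × receivers) array layout are illustrative assumptions, not `spyro`'s API, and `spyro`'s own normalization may differ.

```python
import numpy as np


def l2_misfit(p_sim: np.ndarray, p_obs: np.ndarray, dt: float) -> float:
    """Least-squares misfit J = 0.5 * sum_r integral (p_sim - p_obs)^2 dt.

    Both arrays are assumed to have shape (n_timesteps, n_receivers);
    the time integral is approximated with the rectangle rule.
    """
    if p_sim.shape != p_obs.shape:
        raise ValueError("simulated and observed data must have the same shape")
    residual = p_sim - p_obs
    return 0.5 * dt * float(np.sum(residual**2))


# Identical traces give zero misfit; a constant offset of 1 over
# 1000 steps of 1 ms at 10 receivers gives 0.5 * 0.001 * 1000 * 10 = 5.0
p_obs = np.zeros((1000, 10))
print(l2_misfit(p_obs, p_obs, dt=0.001))        # 0.0
print(l2_misfit(p_obs + 1.0, p_obs, dt=0.001))  # 5.0
```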

FWI uses both the amplitude and the phase information of the recorded data. It can therefore resolve features down to about half the wavelength associated with the source frequency (Fichtner, 2011).
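
As a worked example of the half-wavelength rule: the wavelength is `lambda = c / f`. At the 10 Hz peak frequency used in the example below, a 1.5 km/s wavefield has a 150 m wavelength and a best achievable resolution of about 75 m. At 4.0 km/s the resolution is about 200 m. Using the peak frequency as the single representative value is a simplification, since a broadband source also contains higher frequencies.

```python
def half_wavelength_resolution(velocity_km_s: float, frequency_hz: float) -> float:
    """Approximate FWI resolution limit in metres: half of lambda = c / f."""
    wavelength_m = velocity_km_s * 1000.0 / frequency_hz
    return 0.5 * wavelength_m


print(half_wavelength_resolution(1.5, 10.0))  # 75.0
print(half_wavelength_resolution(4.0, 10.0))  # 200.0
```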

The figure below shows a typical experimental configuration for FWI in a marine environment.

<img src="fwi_setup.png" alt="Drawing" style="width: 600px;"/>

Source: [1].

`spyro` applies continuous Galerkin finite elements to perform full waveform inversion (FWI) for seismic velocity model building. It uses a time-domain FWI approach, with meshes composed of variably sized triangular elements to discretize the domain. To solve both the forward and adjoint-state equations, and to compute a mesh-independent gradient for the FWI process, it uses a fully explicit, variable higher-order (up to degree k = 5 in 2D and k = 3 in 3D) mass-lumping method.
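
To illustrate why mass lumping makes an explicit scheme cheap, the sketch below assembles the 1D linear (P1) finite element mass matrix on a uniform mesh and lumps it by summing each row. This row-sum lumping is a simpler stand-in, not the KMV quadrature-based lumping `spyro` uses in 2D and 3D, but the payoff is the same. The lumped matrix is diagonal, so applying its inverse in each explicit time step reduces to an element-wise division.

```python
import numpy as np


def p1_mass_matrix(n_elements: int, length: float) -> np.ndarray:
    """Consistent mass matrix for 1D linear elements on a uniform mesh."""
    h = length / n_elements
    n_nodes = n_elements + 1
    M = np.zeros((n_nodes, n_nodes))
    # Element mass matrix for P1 elements: (h / 6) * [[2, 1], [1, 2]]
    m_e = (h / 6.0) * np.array([[2.0, 1.0], [1.0, 2.0]])
    for e in range(n_elements):
        idx = [e, e + 1]
        M[np.ix_(idx, idx)] += m_e
    return M


M = p1_mass_matrix(n_elements=4, length=1.0)
M_lumped = M.sum(axis=1)  # row-sum lumping: keep only a diagonal

# Total mass is preserved: both sum to the domain length
print(np.isclose(M.sum(), M_lumped.sum()))  # True

# An explicit update needs M^{-1} r; with lumping this is an element-wise division
r = np.ones(M.shape[0])
u_update = r / M_lumped
```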

By adapting the size of the triangular elements to the expected peak source frequency and to local properties of the wavefield (e.g., the local P-wavespeed), and by leveraging higher-order basis functions, the number of degrees of freedom needed to discretize the domain can be reduced.
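
A common way to size the elements is to fix a target number of cells per wavelength `C` and set the local element size to `h = c_local / (f_peak * C)`. Slow, short-wavelength regions then get small elements and fast regions get large ones. The sketch below is a minimal illustration under assumptions: the function name and the value `C = 3` are hypothetical, and in practice the required `C` depends on the element degree and the accepted error.

```python
import numpy as np


def element_size(vp_km_s: np.ndarray, f_peak_hz: float, cells_per_wavelength: float) -> np.ndarray:
    """Local element size (km) so that each local wavelength spans the given number of cells."""
    wavelength_km = vp_km_s / f_peak_hz
    return wavelength_km / cells_per_wavelength


velocities = np.array([1.5, 4.0])  # e.g. water and a faster layer (km/s)
h_local = element_size(velocities, f_peak_hz=10.0, cells_per_wavelength=3.0)
# about 0.05 km at 1.5 km/s and about 0.133 km at 4.0 km/s
print(h_local)
```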

## 2. FWI 2D example

### 2.1. Importing libraries

To run FWI, the Rapid Optimization Library (ROL) must be installed. First, activate the `Firedrake` environment. Then, install the `patchelf` utility:

`sudo apt install patchelf`

Then install `roltrilinos` and ROL:

`pip3 install --no-cache-dir roltrilinos`

`pip3 install --no-cache-dir ROL`

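```python
# The star import provides COMM_WORLD, File, dx, Function, Constant, assemble,
# TrialFunction, TestFunction, etc., which are used unqualified below
from firedrake import *
import firedrake as fire
import numpy as np
import finat
from ROL.firedrake_vector import FiredrakeVector as FeVector
import ROL
from mpi4py import MPI

import spyro

import psutil
import os
```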

### 2.2. Memory usage analysis

Since FWI is computationally expensive, memory usage should be monitored. The output directory (`fwi_p3_ho/`) and the log files for memory usage and functional values are also set up here.

```python
def get_memory_usage():
    """Return the resident memory usage of the current process in MiB."""
    process = psutil.Process(os.getpid())
    mem = process.memory_info().rss / float(2 ** 20)
    return mem


outdir = "fwi_p3_ho/"
if COMM_WORLD.rank == 0:
    # Create the output directory if needed, then open the log files
    os.makedirs(outdir, exist_ok=True)
    mem = open(outdir + "mem.txt", "w")
    func = open(outdir + "func.txt", "w")
```

### 2.3. Model definition

`spyro` requires a model dictionary containing all relevant parameters and information. A simple model for testing can be created as follows.

```python
model = {}
```

#### 2.3.1. Model options

These options set the finite element method, the quadrature rule, the polynomial degree, the spatial dimension, and the regularization settings.

```python
model["opts"] = {
    "method": "KMV",  # either CG or KMV
    "quadrature": "KMV",  # Equi or KMV
    "degree": 1,  # polynomial degree p
    "dimension": 2,  # spatial dimension
    "regularization": False,  # whether the gradient is regularized
    "gamma": 1e-5,  # regularization parameter
}
```

#### 2.3.2. Model parallelism

Setting the parallelism type to *spatial* is recommended.

```python
model["parallelism"] = {
    "type": "spatial",  # options: automatic (same number of cores for every processor) or spatial
}
```

#### 2.3.3. Model mesh

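The domain dimensions must be defined, together with the mesh file and the initial and true velocity model files. In this example, the mesh and velocity models are generated in code, so the file names are placeholders.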

```python
model["mesh"] = {
    "Lz": 1.0,  # depth in km - always positive
    "Lx": 1.0,  # width in km - always positive
    "Ly": 0.0,  # thickness in km - always positive
    "meshfile": "not_used.msh",
    "initmodel": "not_used.hdf5",
    "truemodel": "not_used.hdf5",
}
```

#### 2.3.4. Model boundary conditions

`spyro` can use a perfectly matched layer (PML): an artificial absorbing layer commonly used in numerical wave simulations to truncate the computational domain and emulate an open (unbounded) medium, so that outgoing waves are not reflected back from the edges of the domain.

```python
# Specify a 0.7-km PML on three sides of the domain to damp outgoing waves.
model["BCs"] = {
    "status": True,  # True or False, used to turn on any type of BC
    "method": "PML",  # either PML or Damping
    "outer_bc": "non-reflective",  # none or non-reflective (outer boundary condition)
    "damping_type": "polynomial",  # polynomial, hyperbolic, shifted_hyperbolic
    "exponent": 2,  # exponent of the polynomial damping profile
    "cmax": 1.5,  # maximum acoustic wave velocity in the PML (km/s)
    "R": 1e-6,  # theoretical reflection coefficient
    "lz": 0.7,  # thickness of the PML in the z-direction (km) - always positive
    "lx": 0.7,  # thickness of the PML in the x-direction (km) - always positive
    "ly": 0.0,  # thickness of the PML in the y-direction (km) - always positive
}
```
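
To see how `exponent`, `cmax`, `R` and the layer thickness shape the damping, the sketch below evaluates a widely used polynomial profile, `sigma(d) = sigma_max * (d / L)^n` with `sigma_max = (n + 1) * c_max * ln(1 / R) / (2 * L)`. Here `d` is the distance into the layer, `L` its thickness and `n` the exponent. Whether `spyro` uses exactly this scaling internally is an assumption, and the function name is hypothetical.

```python
import numpy as np


def polynomial_damping(d, thickness, exponent, cmax, R):
    """Damping coefficient (1/s) at distance d (km) into a PML of the given thickness (km).

    Assumes sigma_max = (n + 1) * cmax * ln(1/R) / (2 * L) and a
    polynomial ramp sigma(d) = sigma_max * (d / L)**n.
    """
    sigma_max = (exponent + 1) * cmax * np.log(1.0 / R) / (2.0 * thickness)
    return sigma_max * (np.asarray(d) / thickness) ** exponent


# Values taken from model["BCs"] above
d = np.linspace(0.0, 0.7, 8)  # sample points across the 0.7-km layer
sigma = polynomial_damping(d, thickness=0.7, exponent=2, cmax=1.5, R=1e-6)
print(sigma)  # zero at the inner edge, growing quadratically toward the outer edge
```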

#### 2.3.5. Model acquisition

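The user must specify the acquisition setup: the source type, the number and positions of the sources, the peak frequency, the source delay, and the receiver locations (here, 10 receivers along a line created with `spyro.create_transect`).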

```python
model["acquisition"] = {
    "source_type": "Ricker",
    "num_sources": 1,
    "source_pos": [(0.5, 0.5)],  # (z, x) in km
    "frequency": 10.0,  # peak frequency (Hz)
    "delay": 1.0,
    "num_receivers": 10,
    "receiver_locations": spyro.create_transect((0.25, 0.2), (0.25, 0.8), 10),
}
```
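
The `Ricker` source type refers to the Ricker wavelet, `r(t) = (1 - 2 pi^2 f^2 tau^2) exp(-pi^2 f^2 tau^2)` with `tau = t - t0`. The NumPy sketch below is an independent implementation of this standard formula for the 10 Hz peak frequency above. It is not `spyro.full_ricker_wavelet`. The time shift `t0 = 0.1 s` is a hypothetical value, since how `spyro` maps its `delay` entry to a time shift is not shown in this notebook.

```python
import numpy as np


def ricker(t: np.ndarray, freq: float, t0: float) -> np.ndarray:
    """Ricker wavelet with peak frequency `freq` (Hz) centred at time `t0` (s)."""
    a = (np.pi * freq * (t - t0)) ** 2
    return (1.0 - 2.0 * a) * np.exp(-a)


dt, tf = 0.001, 1.0
t_axis = np.arange(0.0, tf + dt / 2, dt)           # 1001 samples from 0 to 1 s
ricker_trace = ricker(t_axis, freq=10.0, t0=0.1)   # t0 = 0.1 s is illustrative
print(t_axis[np.argmax(ricker_trace)])             # 0.1, the peak sits at t0
```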

#### 2.3.6. Model time domain

The last part of the model defines the time axis. The `aut_dif` entry turns automatic differentiation on or off.

```python
model["aut_dif"] = {
    "status": True,  # automatic differentiation on/off
}

model["timeaxis"] = {
    "t0": 0.0,  # initial time for the event (s)
    "tf": 1.0,  # final time for the event (s)
    "dt": 0.001,  # time step size (s)
    "amplitude": 1,  # the Ricker wavelet has an amplitude of 1
    "nspool": 2000,  # how frequently to output the solution to PVD files
    "fspool": 1,  # how frequently to save the solution to RAM
}
```
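
Before running, it is worth sanity-checking the time axis. The sketch below computes the number of time steps and an approximate Courant number `cmax * dt / h` for the mesh used later. That mesh has 100 × 100 cells over 1.5 km, so `h = 0.015` km, and the maximum velocity is 4.0 km/s. The helper name is hypothetical. The stability limit of the explicit scheme depends on the element degree, so no threshold is hard-coded here.

```python
def time_axis_summary(t0, tf, dt, cmax, h):
    """Return the number of time steps and the Courant number cmax * dt / h."""
    n_steps = int(round((tf - t0) / dt))
    courant = cmax * dt / h
    return n_steps, courant


n_steps, courant = time_axis_summary(t0=0.0, tf=1.0, dt=0.001, cmax=4.0, h=1.5 / 100)
print(n_steps)            # 1000
print(round(courant, 3))  # 0.267
```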

#### 2.3.7. Mesh generation

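Since the mesh is simple, it can be generated directly with a `firedrake` function. First, the MPI communicators used by `spyro` are initialized from the model dictionary:

```python
# Initialize the ensemble and spatial MPI communicators from model["parallelism"]
comm = spyro.utils.mpi_init(model)

# 100 x 100 cells over a 1.5 km x 1.5 km square, distributed over the spatial communicator
mesh = fire.RectangleMesh(100, 100, 1.5, 1.5, comm=comm.comm)
```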

#### 2.3.8. Models extraction

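The exact and initial-guess velocity models are built by the following function. It creates a two-layer (half-space) model with velocity `v0` where z < 0.5 km and `v1` elsewhere: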

```python
def _make_vp_pml(V, mesh, v0=1.5, v1=4.0):
    """Create a half space"""
    z, x = fire.SpatialCoordinate(mesh)
    velocity = fire.conditional(z < 0.5, v0, v1)
    vp = fire.Function(V, name="vp").interpolate(velocity)
    fire.File("exact_vel.pvd").write(vp)

    return vp
```

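Next, the finite element space is defined, the exact and initial-guess models are created, and the inversion variable `vp` is initialized with the guess:

```python
element = spyro.domains.space.FE_method(
    mesh, model["opts"]["method"], model["opts"]["degree"]
)
V = fire.FunctionSpace(mesh, element)
vp_exact = _make_vp_pml(V, mesh)
vp_guess = _make_vp_pml(V, mesh, v0=1.5, v1=1.5)

# The control variable updated during the inversion starts from the initial guess
vp = vp_guess
```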

#### 2.3.9. Model reading

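The next step creates the sources, receivers, and Ricker wavelet, opens the output files for the control (velocity) and the gradient, and builds the KMV quadrature rule used for mass lumping. It also identifies the water layer (the nodes where the exact velocity is below 1.51 km/s) so that the gradient can later be masked there.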

```python
if COMM_WORLD.rank == 0:
    print(f"The mesh has {V.dim()} degrees of freedom")

sources = spyro.Sources(model, mesh, V, comm)
receivers = spyro.Receivers(model, mesh, V, comm)
wavelet = spyro.full_ricker_wavelet(
    dt=model["timeaxis"]["dt"],
    tf=model["timeaxis"]["tf"],
    freq=model["acquisition"]["frequency"],
)
if comm.ensemble_comm.rank == 0:
    control_file = File(outdir + "control.pvd", comm=comm.comm)
    grad_file = File(outdir + "grad.pvd", comm=comm.comm)
quad_rule = finat.quadrature.make_quadrature(
    V.finat_element.cell, V.ufl_element().degree(), "KMV"
)
dxlump = dx(rule=quad_rule)

water = np.where(vp_exact.dat.data[:] < 1.51)
```

The mesh has 10201 degrees of freedom


### 2.4. Classes and functions

The following helper functions and classes set up the FWI optimization problem for ROL.

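#### 2.4.1. Inner product

ROL needs an inner product on the control space. `L2Inner` uses the L2 inner product, evaluated with the mass matrix assembled with the lumped quadrature measure `dxlump`. Inner products therefore approximate integrals over the domain instead of depending on the number of nodes.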

```python
class L2Inner(object):
    def __init__(self):
        self.A = assemble(
            TrialFunction(V) * TestFunction(V) * dxlump, mat_type="matfree"
        )
        self.Ap = as_backend_type(self.A).mat()

    def eval(self, _u, _v):
        upet = as_backend_type(_u).vec()
        vpet = as_backend_type(_v).vec()
        A_u = self.Ap.createVecLeft()
        self.Ap.mult(upet, A_u)
        return vpet.dot(A_u)
```
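
`L2Inner.eval` computes `<u, v> = v^T M u` with the mass matrix `M`. This approximates the continuous integral of `u * v` over the domain rather than the plain Euclidean dot product of the nodal vectors. The NumPy sketch below shows why this matters, using a 1D P1 mesh with row-sum lumping as a stand-in for the KMV lumped matrix (an assumption for illustration). The weighted inner product of the constant function 1 with itself returns the domain length at every resolution, while the Euclidean one grows with the number of nodes.

```python
import numpy as np


def lumped_mass_1d(n_elements: int, length: float) -> np.ndarray:
    """Diagonal of the row-sum-lumped P1 mass matrix on a uniform 1D mesh."""
    h = length / n_elements
    diag = np.full(n_elements + 1, h)
    diag[[0, -1]] = h / 2.0  # boundary nodes belong to a single element
    return diag


def l2_inner(u: np.ndarray, v: np.ndarray, m_diag: np.ndarray) -> float:
    """Mass-weighted inner product v^T M u for a diagonal M."""
    return float(v @ (m_diag * u))


for n in (10, 100, 1000):
    ones = np.ones(n + 1)
    # weighted product stays close to 1.0; Euclidean product equals n + 1
    print(n, l2_inner(ones, ones, lumped_mass_1d(n, 1.0)), ones @ ones)
```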

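#### 2.4.2. Gradient regularization

When `model["opts"]["regularization"]` is `True`, a Tikhonov (smoothness) term is added to the gradient by solving a mass-matrix system.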

```python
def regularize_gradient(vp, dJ):
    """Tikhonov regularization"""
    m_u = TrialFunction(V)
    m_v = TestFunction(V)
    # Use the lumped (KMV) quadrature measure defined earlier: the resulting
    # mass matrix is diagonal, so a single Jacobi application solves the system
    mgrad = m_u * m_v * dxlump
    ffG = dot(grad(vp), grad(m_v)) * dxlump
    G = mgrad - ffG
    lhsG, rhsG = lhs(G), rhs(G)
    gradreg = Function(V)
    grad_prob = LinearVariationalProblem(lhsG, rhsG, gradreg)
    grad_solver = LinearVariationalSolver(
        grad_prob,
        solver_parameters={
            "ksp_type": "preonly",
            "pc_type": "jacobi",
            "mat_type": "matfree",
        },
    )
    grad_solver.solve()
    dJ += gradreg
    return dJ
```
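
Written out, `regularize_gradient` solves a linear variational problem obtained from the UFL `lhs`/`rhs` split of `G`: find $g_{\mathrm{reg}} \in V$ such that the first equation below holds for every test function $v \in V$. Its right-hand side is the directional derivative of the Tikhonov term $\tfrac{1}{2}\int_\Omega |\nabla v_p|^2\,\mathrm{d}x$, shown in the second equation. So $g_{\mathrm{reg}}$ is the L2 representative of that term's gradient, which the function adds to the data-misfit gradient `dJ`. As written, the function does not scale this term by the `gamma` entry of `model["opts"]`.

$$
\int_\Omega g_{\mathrm{reg}}\, v \,\mathrm{d}x \;=\; \int_\Omega \nabla v_p \cdot \nabla v \,\mathrm{d}x \qquad \forall\, v \in V
$$

$$
\left.\frac{\mathrm{d}}{\mathrm{d}\epsilon}\right|_{\epsilon=0} \frac{1}{2}\int_\Omega \lvert \nabla (v_p + \epsilon v) \rvert^2 \,\mathrm{d}x \;=\; \int_\Omega \nabla v_p \cdot \nabla v \,\mathrm{d}x
$$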

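#### 2.4.3. Objective function

The `Objective` class implements the interface ROL expects. `value` runs the forward simulation in the current model `vp` and returns the misfit functional. `gradient` computes the gradient with `spyro.solvers.gradient`, optionally regularizes it, and masks the water layer. `update` copies the current control vector into `vp`.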

```python
class Objective(ROL.Objective):
    def __init__(self, inner_product):
        ROL.Objective.__init__(self)
        self.inner_product = inner_product
        self.p_guess = None
        self.misfit = 0.0
        # Observed data: shot records previously saved to disk by spyro
        # (they are not generated in this notebook)
        self.p_exact_recv = spyro.io.load_shots(model, comm)

    def value(self, x, tol):
        """Compute the functional"""
        J_total = np.zeros((1))
        self.p_guess, p_guess_recv = spyro.solvers.forward(
            model,
            mesh,
            comm,
            vp,
            sources,
            wavelet,
            receivers,
        )
        self.misfit = spyro.utils.evaluate_misfit(
            model, p_guess_recv, self.p_exact_recv
        )
        J_total[0] += spyro.utils.compute_functional(model, self.misfit, velocity=vp)
        J_total = COMM_WORLD.allreduce(J_total, op=MPI.SUM)
        J_total[0] /= comm.ensemble_comm.size
        if comm.comm.size > 1:
            J_total[0] /= comm.comm.size

        if COMM_WORLD.rank == 0:
            mem.write(str(get_memory_usage()))
            func.write(str(J_total[0]))
            mem.write("\n")
            func.write("\n")

        return J_total[0]

    def gradient(self, g, x, tol):
        """Compute the gradient of the functional"""
        dJ = Function(V, name="gradient")
        dJ_local = spyro.solvers.gradient(
            model,
            mesh,
            comm,
            vp,
            receivers,
            self.p_guess,
            self.misfit,
        )
        if comm.ensemble_comm.size > 1:
            comm.allreduce(dJ_local, dJ)
        else:
            dJ = dJ_local
        dJ /= comm.ensemble_comm.size
        if comm.comm.size > 1:
            dJ /= comm.comm.size
        # regularize the gradient if asked.
        if model["opts"]["regularization"]:
            dJ = regularize_gradient(vp, dJ)
        # mask the water layer
        dJ.dat.data[water] = 0.0
        # Visualize
        if comm.ensemble_comm.rank == 0:
            grad_file.write(dJ)
        g.scale(0)
        g.vec += dJ

    def update(self, x, flag, iteration):
        vp.assign(Function(V, x.vec, name="velocity"))
        # If iteration reduces functional, save it.
        if iteration >= 0:
            if comm.ensemble_comm.rank == 0:
                control_file.write(vp)
```

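### 2.5. Optimization parameters and execution

The optimizer is configured through a ROL parameter list: a limited-memory BFGS secant approximation storing 10 vectors, with at most 100 iterations. The velocity is constrained to lie between 1.0 and 5.0 km/s, the problem is solved with ROL's line-search algorithm, and the final model is written to `res.pvd`.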

```python
paramsDict = {
    "General": {"Secant": {"Type": "Limited-Memory BFGS", "Maximum Storage": 10}},
    "Step": {
        "Type": "Augmented Lagrangian",
        "Augmented Lagrangian": {
            "Subproblem Step Type": "Line Search",
            "Subproblem Iteration Limit": 5.0,
        },
        "Line Search": {"Descent Method": {"Type": "Quasi-Newton Step"}},
    },
    "Status Test": {
        "Gradient Tolerance": 1e-16,
        "Iteration Limit": 100,
        "Step Tolerance": 1.0e-16,
    },
}

params = ROL.ParameterList(paramsDict, "Parameters")

inner_product = L2Inner()

obj = Objective(inner_product)

u = Function(V, name="velocity").assign(vp)
opt = FeVector(u.vector(), inner_product)

# Add control bounds to the problem (uses more RAM)
xlo = Function(V)
xlo.interpolate(Constant(1.0))
x_lo = FeVector(xlo.vector(), inner_product)

xup = Function(V)
xup.interpolate(Constant(5.0))
x_up = FeVector(xup.vector(), inner_product)

bnd = ROL.Bounds(x_lo, x_up, 1.0)

algo = ROL.Algorithm("Line Search", params)

algo.run(opt, obj, bnd)

if comm.ensemble_comm.rank == 0:
    File("res.pvd", comm=comm.comm).write(vp)


if COMM_WORLD.rank == 0:
    func.close()
    mem.close()
```
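
ROL handles the velocity bounds internally. The basic idea of a bound constraint can still be seen in a projected-gradient step: move along the negative gradient, then clip the velocity back into the admissible interval [1.0, 5.0] km/s. The sketch below is a minimal illustration of the projection only, with a hypothetical fixed step length. It is not what ROL's line-search algorithm does in detail.

```python
import numpy as np


def projected_gradient_step(vp_nodes, grad_nodes, step, lower=1.0, upper=5.0):
    """One steepest-descent step followed by projection onto [lower, upper]."""
    return np.clip(vp_nodes - step * grad_nodes, lower, upper)


vp_nodes = np.array([1.5, 1.5, 1.5])
grad_nodes = np.array([-10.0, 0.5, 20.0])
new_vp = projected_gradient_step(vp_nodes, grad_nodes, step=0.5)
# first entry is clipped to the upper bound 5.0, last to the lower bound 1.0
print(new_vp)
```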


## 3. References

[1] ROBERTS, K. J. et al. spyro: a firedrake-based wave propagation and full waveform inversion finite element solver. Geoscientific Model Development Discussions, v. 2021, p. 1–47, 2021. Available at: <https://gmd.copernicus.org/preprints/gmd-2021-363/>