In [None]:
from firedrake import (
    RectangleMesh,
    FunctionSpace,
    Function,
    SpatialCoordinate,
    conditional,
    File,
)

In [None]:
# NOTE: import order is load-bearing here. `from firedrake import *` may export
# its own `plot` name, so it must stay BEFORE `import matplotlib.pyplot as plot`.
from firedrake import *
# from firedrake_adjoint import *
import firedrake
from mpi4py import MPI
import spyro
import numpy as np
import math
import numpy                  as np
import matplotlib.pyplot      as plot
import matplotlib.ticker      as mticker
from matplotlib               import cm, ticker
from mpl_toolkits.axes_grid1  import make_axes_locatable

In [None]:
model = {}

# Choose method and parameters
model["opts"] = {
    "method": "KMV",  # either CG or KMV
    "quadrature": "KMV",  # Equi or KMV
    "degree": 1,  # p order
    "dimension": 2,  # dimension
}

# Number of cores for the shot. For simplicity, we keep things serial.
# spyro however supports both spatial parallelism and "shot" parallelism.
model["parallelism"] = {
    "type": "off",  # options: automatic (same number of cores for every processor), custom, off.
    "custom_cores_per_shot": [],  # only if the user wants a different number of cores for every shot.
    # input is a list of integers with the length of the number of shots.
}

# Domain size (km) without the PML. Keep these in sync with the mesh that is
# actually built in the mesh cell: RectangleMesh(100, 100, 1.0, 1.0).
model["mesh"] = {
    "Lz": 1.0,  # depth in km - always positive
    "Lx": 1.0,  # width in km - always positive
    "Ly": 0.0,  # thickness in km - always positive
    "meshfile": "not_used.msh",
    "initmodel": "not_used.hdf5",
    "truemodel": "not_used.hdf5",
}

# PML settings. The PML is disabled here ("status": False); the 250-m layer
# thicknesses below presumably only take effect if it is switched on.
model["PML"] = {
    "status": False,  # True or false
    "outer_bc": "non-reflective",  #  None or non-reflective (outer boundary condition)
    "damping_type": "polynomial",  # polynomial, hyperbolic, shifted_hyperbolic
    "exponent": 2,  # damping layer has a exponent variation
    "cmax": 4.7,  # maximum acoustic wave velocity in PML - km/s
    "R": 1e-6,  # theoretical reflection coefficient
    "lz": 0.25,  # thickness of the PML in the z-direction (km) - always positive
    "lx": 0.25,  # thickness of the PML in the x-direction (km) - always positive
    "ly": 0.0,  # thickness of the PML in the y-direction (km) - always positive
}

# A single Ricker source with a 3 Hz peak frequency placed at (0.1, 0.5).
# The solution is recorded at `num_receivers` receivers laid out on a straight
# transect from (0.10, 0.1) to (0.10, 0.9) by the helper `create_transect`.
# The count is defined once so "num_receivers" and the transect cannot disagree.
num_receivers = 100
model["acquisition"] = {
    "source_type": "Ricker",
    "num_sources": 1,
    "source_pos": [(0.1, 0.5)],
    "frequency": 3.0,
    "delay": 1.0,
    "num_receivers": num_receivers,
    "receiver_locations": spyro.create_transect(
        (0.10, 0.1), (0.10, 0.9), num_receivers
    ),
}

# Simulate for 1.0 second.
model["timeaxis"] = {
    "t0": 0.0,  #  Initial time for event
    "tf": 1.00,  # Final time for event
    "dt": 0.001,  # timestep size
    "amplitude": 1,  # the Ricker has an amplitude of 1.
    "nspool": 100,  # how frequently to output solution to pvds
    "fspool": 1,  # how frequently to save solution to RAM
}




In [None]:
# Create the computational environment first: the mesh must live on the
# spatial communicator of the ensemble returned by spyro (comm.comm), not on
# COMM_WORLD, otherwise shot parallelism ("automatic") would distribute the
# mesh over the wrong set of ranks. With parallelism "off" both coincide.
comm = spyro.utils.mpi_init(model)

# 100 x 100 cells on a 1.0 x 1.0 square.
mesh = RectangleMesh(100, 100, 1.0, 1.0, comm=comm.comm)

element = spyro.domains.space.FE_method(
    mesh, model["opts"]["method"], model["opts"]["degree"]
)
V = FunctionSpace(mesh, element)


In [None]:
# spyro orders 2-D coordinates as (z, x); the source/receiver positions in
# `model["acquisition"]` use the same ordering, so the first coordinate is z.
z, x = SpatialCoordinate(mesh)

# Two-layer velocity model: 1.5 where z > 0.35, 3.0 elsewhere
# (presumably km/s, matching the PML "cmax" units).
vp = Function(V, name="vp").interpolate(conditional(z > 0.35, 1.5, 3.0))


In [None]:
# Source-injection and receiver-sampling operators share the same inputs.
operator_args = (model, mesh, V, comm)
sources = spyro.Sources(*operator_args).create()
receivers = spyro.Receivers(*operator_args).create()


In [None]:
# Forward solve for shot 0. `usol` is handed to the adjoint solve below;
# `usol_rec` presumably holds the traces sampled at the receivers.
usol, usol_rec = spyro.solvers.Leapfrog(
    model, mesh, comm, vp, sources, receivers, source_num=0
)

In [None]:
# NOTE(review): no observed data is subtracted, so the "misfit" is the
# recorded trace itself (equivalent to comparing against all-zero data).
# Confirm this is the intended functional for this gradient experiment.
misfit = usol_rec
J_total = spyro.utils.compute_functional(model, comm, misfit)


In [None]:
# Adjoint solve driven by the misfit; yields this ensemble member's local
# contribution to the gradient dJ/dc (summed across members in the next cell).
dJdC_local = spyro.solvers.Leapfrog_adjoint(
    model, mesh, comm, vp, receivers, usol, misfit
)

In [None]:
# Sum the gradient over all ensemble members into a NEW Function, leaving
# dJdC_local untouched so re-running this cell never re-sums summed data.
# `MPI` and `firedrake` are imported in the top import cell.
dJdC = Function(dJdC_local.function_space(), name="dJdC")
dJdC.dat.data[:] = comm.ensemble_comm.allreduce(
    dJdC_local.dat.data_ro[:], op=MPI.SUM
)

fig, axes = plot.subplots()
axes.set_aspect("equal")
axes.set_title("Gradient dJ/dc")
colors = firedrake.tripcolor(dJdC, axes=axes, shading="gouraud", cmap="jet")
fig.colorbar(colors, ax=axes)
fig.savefig("grad.png", dpi=100, format="png")