# Implementing pySDC problems using Firedrake

## What is Firedrake?
[Firedrake](https://www.firedrakeproject.org) is an elaborate framework for discretizing PDEs using the finite element method (FEM), sharing many characteristics with FEniCS.
In very basic terms, Firedrake allows you to write code that stays very close to the mathematical equations describing a PDE in weak form.
Then you hand this description to a solver in Firedrake and wait for the result - no headaches necessary!
This is an immensely powerful tool for setting up complicated problems on complicated meshes while relying on a lot of existing automation rather than coding everything yourself.
For a few basic examples and explanations, see the [Firedrake documentation](https://www.firedrakeproject.org/documentation.html).


## Installing Firedrake
Installing Firedrake can be non-trivial, but it keeps improving.
Please consult the respective [documentation](https://www.firedrakeproject.org/install.html#id9).
If you get stuck, open a discussion on the [GitHub page](https://github.com/firedrakeproject/firedrake/discussions), where the Firedrake developers are responsive and helpful.

To run this notebook, you need to make a kernel with your Firedrake installation.
Activate a virtual environment with Firedrake installed, then install pySDC and the Jupyter packages with
```
pip install -e <path-to-pySDC>
pip install ipykernel
pip install jupyter
```

Afterwards, generate the kernel and start the notebook:
```
python -m ipykernel install --user --name=pySDC_Firedrake
python -m jupyter lab
```

## Coupling pySDC to Firedrake
pySDC is already coupled to Firedrake, so don't worry.
However, we will briefly go through the existing coupling here to illustrate how to couple pySDC to any library.
pySDC is generally written to be agnostic to the datatype, but coupling requires coding up a new datatype that respects a few properties, has a rule for generating empty data containers, and can take care of communication.

With `u`, `v` data as used in the library you want to couple to and `a` a float, the following must hold:
 - `abs(u)` must return a float with the norm across the entire spatial domain
 - `a*u + v - v*u` must be implemented and the result must be of the same type as `u`.
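
These requirements can be tried out on a candidate datatype without involving pySDC at all. The following is a minimal sketch, assuming a hypothetical `ToyVector` class backed by a plain Python list; it is not part of pySDC or Firedrake. It restricts itself to scalar multiplication from the left, addition and subtraction, which are the operators the wrapper in this notebook defines.

```python
import math


class ToyVector:
    """Hypothetical stand-in for library data, backed by a Python list."""

    def __init__(self, values):
        self.values = [float(v) for v in values]

    def __add__(self, other):
        return ToyVector(x + y for x, y in zip(self.values, other.values))

    def __sub__(self, other):
        return ToyVector(x - y for x, y in zip(self.values, other.values))

    def __rmul__(self, factor):
        # called for `factor * self` when `factor` is a plain number
        return ToyVector(factor * x for x in self.values)

    def __abs__(self):
        # a single float summarizing the whole vector (Euclidean norm here)
        return math.sqrt(sum(x * x for x in self.values))


u = ToyVector([3.0, 4.0])
v = ToyVector([1.0, 1.0])
a = 2.0

# combines scalar multiplication, addition and subtraction
w = a * u + v - u
print(type(w).__name__, w.values, abs(u))
```

The result `w` is again a `ToyVector`, and `abs(u)` is a plain float, which is the contract described in the list above.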
  
For communication, the following methods, named after their MPI counterparts, need to be implemented in the datatype:
 - `bcast`: Broadcast data
 - `isend`: Non-blocking send of the data
 - `irecv`: Non-blocking receive of the data

So, coupling pySDC to a library typically means writing a wrapper for the data or subclassing the datatype and adding only the above-mentioned functionality.

In the case of Firedrake, we write a wrapper for `firedrake.Function` which we call `firedrake_mesh` here.
We start with the `__init__` function, which is called during object instantiation.
In the problem classes, we will define a function that we can use to instantiate a respective `firedrake_mesh`.
Here, we will make use of `__mro__`, which holds the [method resolution order](https://docs.python.org/3/howto/mro.html) of the class we access it on, that is, the chain of classes it inherits from.
The inheritance relationship we are looking for here depends on some intricacies of Firedrake that you needn't worry about now.
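
To make the `__mro__` lookup concrete, here is a small sketch with three hypothetical classes that have nothing to do with Firedrake. Note that `__mro__` is an attribute of the class, which is why the wrapper below accesses it through `type(init)`.

```python
class Base:
    pass


class Middle(Base):
    pass


class Leaf(Middle):
    pass


obj = Leaf()

# the method resolution order lists the class itself and all its ancestors
mro_names = [cls.__name__ for cls in type(obj).__mro__]
print(mro_names)

# membership in the MRO answers "does this object inherit from Middle?"
is_middle_via_mro = Middle in type(obj).__mro__
is_middle_via_isinstance = isinstance(obj, Middle)
```

For ordinary classes like these, the membership test and `isinstance` give the same answer.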

In [1]:
import jdc  # required to split the class definition into multiple cells...

import firedrake as fd

from pySDC.core.errors import DataError
from pySDC.helpers.firedrake_ensemble_communicator import FiredrakeEnsembleCommunicator


class firedrake_mesh(object):
    """
    Wrapper for firedrake function data.

    Attributes:
        functionspace (firedrake.Function): firedrake data
    """

    def __init__(self, init, val=0.0):
        if fd.functionspaceimpl.WithGeometry in type(init).__mro__:
            self.functionspace = fd.Function(init)
            self.functionspace.assign(val)
        elif fd.Function in type(init).__mro__:
            self.functionspace = fd.Function(init)
        elif isinstance(init, firedrake_mesh):
            self.functionspace = init.functionspace.copy(deepcopy=True)
        else:
            raise DataError('cannot initialize firedrake_mesh from %s' % type(init))
   

Next, we make this a wrapper for `firedrake.Function` via the `__getattr__` method.
If you access `a.key` on an object `a` that has no attribute or method `key`, Python calls `a.__getattr__(key)` instead, which makes it easy to pass requests on to the wrapped object.

In [2]:
%%add_to firedrake_mesh

def __getattr__(self, key):
    return getattr(self.functionspace, key)
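
The delegation pattern can be seen in isolation with two hypothetical classes, `Inner` and `Wrapper`, used here purely for illustration. The sketch also shows one limitation: operators and built-ins such as `len` look up special methods on the class and bypass `__getattr__`, which is one reason the arithmetic operators need to be defined explicitly on the wrapper.

```python
class Inner:
    def __init__(self):
        self.name = "inner"

    def describe(self):
        return f"I am {self.name}"

    def __len__(self):
        return 3


class Wrapper:
    def __init__(self, wrapped):
        self.wrapped = wrapped

    def own_method(self):
        return "handled by the wrapper"

    def __getattr__(self, key):
        # only reached when normal attribute lookup on the wrapper fails
        return getattr(self.wrapped, key)


w = Wrapper(Inner())

own = w.own_method()        # found on the wrapper itself
forwarded = w.describe()    # not found on the wrapper, forwarded to Inner
explicit_len = w.__len__()  # explicit attribute access is forwarded as well

try:
    len(w)                  # implicit special-method lookup is not forwarded
    len_forwarded = True
except TypeError:
    len_forwarded = False
```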

Next up, we define addition, subtraction and right multiplication for the new datatype:

In [3]:
%%add_to firedrake_mesh

def __add__(self, other):
    if isinstance(other, type(self)):
        me = firedrake_mesh(other)
        me.functionspace.assign(self.functionspace + other.functionspace)
        return me
    else:
        raise DataError("Type error: cannot add %s to %s" % (type(other), type(self)))

def __sub__(self, other):
    if isinstance(other, type(self)):
        me = firedrake_mesh(other)
        me.functionspace.assign(self.functionspace - other.functionspace)
        return me
    else:
        raise DataError("Type error: cannot subtract %s from %s" % (type(other), type(self)))

def __rmul__(self, other):
    """
    Overloading the right multiply by scalar factor

    Args:
        other (float): factor
    Raises:
        DataError: if other is not a float
    Returns:
        firedrake_mesh: copy of original values scaled by factor
    """

    try:
        me = firedrake_mesh(self)
        me.functionspace.assign(other * self.functionspace)
        return me
    except TypeError as e:
        raise DataError("Type error: cannot multiply %s to %s" % (type(other), type(self))) from e
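
Why `__rmul__` and not `__mul__`? In an expression like `a * u` with a float `a`, Python first asks the float to do the multiplication and only falls back to the right operand when the float declines. The sketch below demonstrates this with a hypothetical `Scaled` class that, like the wrapper above, defines only `__rmul__`.

```python
class Scaled:
    """Hypothetical datatype that only supports multiplication from the left."""

    def __init__(self, value):
        self.value = value

    def __rmul__(self, factor):
        return Scaled(factor * self.value)


u = Scaled(1.5)

# the float does not know how to multiply with a Scaled object ...
float_declines = (2.0).__mul__(u) is NotImplemented

# ... so Python falls back to u.__rmul__(2.0)
left = 2.0 * u

# with the operands swapped there is no __mul__ to call
try:
    u * 2.0
    right_works = True
except TypeError:
    right_works = False
```

Expressions in the form `factor * data` therefore work, whereas `data * factor` would need an additional `__mul__`.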

Now, we take care of the norm, which is as simple as calling the Firedrake function that we want:

In [4]:
%%add_to firedrake_mesh

def __abs__(self):
    """
    Overloading the abs operator for mesh types

    Returns:
        float: L2 norm
    """

    return fd.norm(self.functionspace, 'L2')
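
The `'L2'` norm here is the integral norm, the square root of the integral of the squared function over the domain, rather than the Euclidean norm of the coefficient vector. As a Firedrake-free illustration, the following sketch approximates that integral on the unit interval with the midpoint rule; the helper name `l2_norm_midpoint` is made up for this example.

```python
import math


def l2_norm_midpoint(func, n_cells=1000):
    """Approximate sqrt(int_0^1 func(x)^2 dx) with the midpoint rule."""
    h = 1.0 / n_cells
    integral = sum(func((i + 0.5) * h) ** 2 for i in range(n_cells)) * h
    return math.sqrt(integral)


# for sin(pi x) the exact value is sqrt(1/2)
norm_sin = l2_norm_midpoint(lambda x: math.sin(math.pi * x))
print(norm_sin, math.sqrt(0.5))
```

Because the norm is an integral, it does not grow when the mesh is refined, unlike the Euclidean norm of the vector of nodal values.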

Finally: Communication.
Firedrake has "ensemble communicators", which were built by Josh for space-time parallelism.
In pySDC, we have written a wrapper around them that makes a few things easier, but essentially, communication is again handled with Firedrake functions.

In [5]:
%%add_to firedrake_mesh

def isend(self, dest=None, tag=None, comm=None):
    """
    Routine for sending data forward in time (non-blocking)

    Args:
        dest (int): target rank
        tag (int): communication tag
        comm: communicator

    Returns:
        request handle
    """
    assert (
        type(comm) == FiredrakeEnsembleCommunicator
    ), f'Need to give a FiredrakeEnsembleCommunicator here, not {type(comm)}'
    return comm.Isend(self.functionspace, dest=dest, tag=tag)

def irecv(self, source=None, tag=None, comm=None):
    """
    Routine for receiving data in time (non-blocking)

    Args:
        source (int): source rank
        tag (int): communication tag
        comm: communicator

    Returns:
        request handle
    """
    assert (
        type(comm) == FiredrakeEnsembleCommunicator
    ), f'Need to give a FiredrakeEnsembleCommunicator here, not {type(comm)}'
    return comm.Irecv(self.functionspace, source=source, tag=tag)

def bcast(self, root=None, comm=None):
    """
    Routine for broadcasting values

    Args:
        root (int): process with value to broadcast
        comm: communicator

    Returns:
        broadcasted values
    """
    assert (
        type(comm) == FiredrakeEnsembleCommunicator
    ), f'Need to give a FiredrakeEnsembleCommunicator here, not {type(comm)}'
    comm.Bcast(self.functionspace, root=root)
    return self
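
The non-blocking calls return a request handle, and the received data should only be relied upon after waiting on that handle. The sketch below mimics this pattern with a purely hypothetical single-process `ToyComm`; it is neither MPI nor the Firedrake ensemble communicator, and only illustrates the calling convention of `Isend`, `Irecv` and `wait`.

```python
from collections import deque


class ToyRequest:
    """Hypothetical request handle: the transfer completes in wait()."""

    def __init__(self, complete):
        self._complete = complete
        self._done = False

    def wait(self):
        if not self._done:
            self._complete()
            self._done = True


class ToyComm:
    """Single-process stand-in for a communicator, for illustration only."""

    def __init__(self):
        self._mailbox = {}  # tag -> queue of pending messages

    def Isend(self, buf, dest=None, tag=0):
        # store a copy so later changes to buf do not alter the message
        self._mailbox.setdefault(tag, deque()).append(list(buf))
        return ToyRequest(lambda: None)

    def Irecv(self, buf, source=None, tag=0):
        def complete():
            buf[:] = self._mailbox[tag].popleft()  # fill the buffer in place

        return ToyRequest(complete)


comm = ToyComm()
send_buf = [1.0, 2.0, 3.0]
recv_buf = [0.0, 0.0, 0.0]

send_req = comm.Isend(send_buf, dest=1, tag=7)
recv_req = comm.Irecv(recv_buf, source=0, tag=7)
before_wait = list(recv_buf)  # in this toy, still the old content

send_req.wait()
recv_req.wait()
print(before_wait, recv_buf)
```

With a real communicator the data may arrive before `wait` is called, but it is only guaranteed to be there afterwards.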

As you can see, we didn't do a whole lot.
Adding a coupling is just about bridging the gap between the pySDC interface and the library interface, not about implementing any new functionality.
You will need intimate knowledge of the library you are coupling to, and you may need some workarounds that are not optimal performance-wise, but it is no Hexenwerk (German for "witchcraft", used much like "rocket science").
You definitely needn't know a lot about pySDC.

## Using Firedrake to discretize the heat equation in pySDC

In [6]:
from pySDC.core.problem import Problem, WorkCounter
import numpy as np

class HeatEquation(Problem):
    dtype_u = firedrake_mesh
    dtype_f = firedrake_mesh
    
    def __init__(self, n, nu, order, comm):
        # prepare Firedrake mesh and function space
        self.mesh = fd.UnitIntervalMesh(n, comm=comm)
        self.V = fd.FunctionSpace(self.mesh, "CG", order)

        # prepare pySDC problem class infrastructure by passing the function space to super init
        super().__init__(self.V)
        self._makeAttributeAndRegister(
            'n', 'nu', 'order', 'comm', localVars=locals(), readOnly=True
        )

        
        # prepare caches and IO variables for solvers
        self.solvers = {}
        self.tmp_in = fd.Function(self.V)
        self.tmp_out = fd.Function(self.V)
        
        self.work_counters['solver_setup'] = WorkCounter()
        self.work_counters['solves'] = WorkCounter()
        self.work_counters['rhs'] = WorkCounter()

In [7]:
%%add_to HeatEquation

def eval_f(self, u, t):
    # construct and cache a solver for evaluating the Laplacian
    # (a single leading underscore is used on purpose: a double underscore would be
    # name-mangled if this method were defined inside the class body, and the
    # hasattr check would then never find the cached solver)
    if not hasattr(self, '_solv_eval_f'):
        v = fd.TestFunction(self.V)
        u_trial = fd.TrialFunction(self.V)

        a = u_trial * v * fd.dx
        L = -fd.inner(self.nu * fd.nabla_grad(self.tmp_in), fd.nabla_grad(v)) * fd.dx

        bcs = [fd.bcs.DirichletBC(self.V, fd.Constant(0), area) for area in [1, 2]]

        prob = fd.LinearVariationalProblem(a, L, self.tmp_out, bcs=bcs)
        self._solv_eval_f = fd.LinearVariationalSolver(prob)

    # copy the solution we want to evaluate at into the input buffer
    self.tmp_in.assign(u.functionspace)

    # perform the solve using the cached solver
    self._solv_eval_f.solve()

    # instantiate an empty data container
    me = self.dtype_f(self.init)

    # copy the result of the solver from the output buffer to the variable this function returns
    me.assign(self.tmp_out)

    self.work_counters['rhs']()

    return me
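
The cell above comes without a derivation, so here is one reading of what `eval_f` solves, reconstructed from the forms `a` and `L` in the code. With $u$ the input (`tmp_in`), $f$ the output (`tmp_out`), $\nu$ the diffusion coefficient `nu` and $v$ a test function that vanishes on the Dirichlet boundary, multiplying $f = \nu\, \partial_x^2 u$ by $v$ and integrating by parts on the unit interval gives

$$
\int_0^1 f\, v \,\mathrm{d}x = \int_0^1 \nu\, \partial_x^2 u\, v \,\mathrm{d}x = \left[\nu\, \partial_x u\, v\right]_0^1 - \int_0^1 \nu\, \partial_x u\, \partial_x v \,\mathrm{d}x = -\int_0^1 \nu\, \partial_x u\, \partial_x v \,\mathrm{d}x .
$$

The left-hand side corresponds to `a` (a mass matrix applied to the unknown) and the right-hand side to `L`, which is why evaluating the right-hand side of the heat equation requires a linear solve here.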

In [8]:
%%add_to HeatEquation

def solve_system(self, rhs, dt, *args, **kwargs):
    r"""
    Linear solver for :math:`(M - dt * nu * Lap) u = rhs`.
    """

    # construct and cache a solver for the current dt (preconditioner entry times step size)
    if dt not in self.solvers:

        u = fd.TrialFunction(self.V)
        v = fd.TestFunction(self.V)

        a = u * v * fd.dx + fd.Constant(dt) * fd.inner(self.nu * fd.nabla_grad(u), fd.nabla_grad(v)) * fd.dx
        L = fd.inner(self.tmp_in, v) * fd.dx

        bcs = [fd.bcs.DirichletBC(self.V, fd.Constant(0), area) for area in [1, 2]]

        prob = fd.LinearVariationalProblem(a, L, self.tmp_out, bcs=bcs)
        self.solvers[dt] = fd.LinearVariationalSolver(prob)

        self.work_counters['solver_setup']()

    # copy solver rhs to the input buffer. Copying also to the output buffer uses it as initial guess
    self.tmp_in.assign(rhs.functionspace)
    self.tmp_out.assign(rhs.functionspace)

    # call the cached solver
    self.solvers[dt].solve()

    # copy from output buffer to return variable
    me = self.dtype_u(self.init)
    me.assign(self.tmp_out)

    self.work_counters['solves']()
    return me
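
The caching strategy in `solve_system` is independent of Firedrake: an expensive setup is done once per distinct `dt` and reused afterwards. Below is a minimal sketch of that idea, assuming a hypothetical `ToySolverCache` in which the "solver" is just a scalar division standing in for a real linear solve.

```python
class ToySolverCache:
    """Hypothetical scalar analogue of caching one solver per step size."""

    def __init__(self, nu):
        self.nu = nu
        self.solvers = {}
        self.setup_count = 0

    def _build_solver(self, dt):
        # stand-in for an expensive setup such as assembling a matrix
        self.setup_count += 1
        factor = 1.0 / (1.0 + dt * self.nu)
        return lambda rhs: factor * rhs

    def solve(self, rhs, dt):
        # build the solver only the first time this dt is seen
        if dt not in self.solvers:
            self.solvers[dt] = self._build_solver(dt)
        return self.solvers[dt](rhs)


cache = ToySolverCache(nu=2.0)
results = [cache.solve(1.0, dt) for dt in (0.5, 0.5, 0.25, 0.5)]
print(results, cache.setup_count)
```

One caveat of using floats as dictionary keys is that two step sizes that differ only by rounding error count as different keys and trigger separate setups.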

In [9]:
%%add_to HeatEquation

def x(self):
    return fd.SpatialCoordinate(self.mesh)

def u_exact(self, t):
    me = self.u_init
    me.interpolate(np.exp(-self.nu * np.pi**2 * t) * fd.sin(np.pi * self.x()[0]))
    return me

In [10]:
from mpi4py import MPI
prob = HeatEquation(n=128, nu=1e-2, order=4, comm=MPI.COMM_WORLD)

u0 = prob.u_exact(0)

f_expect = -prob.nu * np.pi**2 * u0
f = prob.eval_f(u0, 0)
assert abs(f-f_expect) < 1e-8
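
The expected value `f_expect` follows from the function set in `u_exact`. As a short check, assuming the problem being solved is $\partial_t u = \nu\, \partial_x^2 u$ on the unit interval with homogeneous Dirichlet boundary conditions:

$$
u(x,t) = e^{-\nu \pi^2 t} \sin(\pi x), \qquad \nu\, \partial_x^2 u = -\nu \pi^2 u, \qquad \partial_t u = -\nu \pi^2 u, \qquad u(0,t) = u(1,t) = 0 .
$$

So the right-hand side evaluated at $u_0 = u(\cdot, 0)$ should equal $-\nu \pi^2 u_0$ up to discretization error, which is what the assertion compares against.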

In [11]:
dt = 1e-2

assert abs(prob.solve_system(u0, dt) - prob.u_exact(dt)) < 1e-6
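
The tolerance of `1e-6` in the last assertion can be made plausible with a back-of-the-envelope estimate. The sketch below assumes that `solve_system(u0, dt)` amounts to one implicit Euler step applied to the single mode $\sin(\pi x)$ and ignores the spatial discretization error; both are assumptions for illustration, not statements about the Firedrake solve itself.

```python
import math

nu = 1e-2
dt = 1e-2
lam = nu * math.pi**2  # decay rate of the mode sin(pi x)

# amplitude after one step: exact solution vs. one implicit Euler step
exact_factor = math.exp(-lam * dt)
implicit_euler_factor = 1.0 / (1.0 + lam * dt)

# error in the amplitude and its leading-order estimate (lam * dt)^2 / 2
amplitude_error = abs(implicit_euler_factor - exact_factor)
leading_order = 0.5 * (lam * dt) ** 2

# the L2 norm of sin(pi x) on [0, 1] is sqrt(1/2)
l2_error_estimate = amplitude_error * math.sqrt(0.5)
print(amplitude_error, leading_order, l2_error_estimate)
```

Under these assumptions, the time-stepping error of a single step stays below the tolerance used in the assertion.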