In [None]:
from FEnicS_base import *
import dolfinx
from dolfinx import mesh, fem, io, nls
from dolfinx.fem import VectorFunctionSpace
from mpi4py import MPI
import numpy as np
import ufl
import matplotlib.pyplot as plt

import jsonpickle as jp
import shutil
import re
import inspect
from tqdm import tqdm

def function_plots(
    functions,
    fig=None,
    ax=None,
    show_points=False,
):
    """Plot one curve per finite-element function on a shared axes.

    Each function is drawn as its coefficient array against the first
    coordinate of its degrees of freedom, sorted by that coordinate.

    Args:
        functions: Iterable of objects exposing
            ``function_space.tabulate_dof_coordinates()``, ``x.array`` and
            ``name``; ``name`` becomes the legend label.
        fig (plt.Figure, optional): Figure to draw into when ``ax`` is not
            given; its current axes are used.
        ax (plt.Axes, optional): Axes to draw on. Takes precedence over
            ``fig``.
        show_points (bool): If true, also scatter the individual points.

    Returns:
        The axes that were drawn on.
    """
    if ax is None:
        if fig is None:
            fig, ax = plt.subplots(facecolor='White')
            fig.set_size_inches(16, 8)
        else:
            # A figure without explicit axes: draw on its current axes
            # rather than leaving ``ax`` unset.
            ax = fig.gca()
    for func in functions:
        x = func.function_space.tabulate_dof_coordinates()[:, 0]
        y = func.x.array
        cord = np.array([x, y])
        # Order the samples by coordinate so the line is drawn left to right
        cord = cord[:, np.argsort(cord[0])]
        ax.plot(cord[0], cord[1], label=func.name, linewidth=1)
        if show_points:
            ax.scatter(cord[0], cord[1], s=0.5)
    ax.legend(
        bbox_to_anchor=(1.01, 0.5),
        borderaxespad=0,
        loc='center left',
    )
    return ax


In [None]:
# Components
class Light_collection:
    """Collection of 1D switch profiles around ``x0``, selected by name.

    Every profile method is registered under its own name through the
    ``_kind_ready`` decorator; ``create(kind)`` dispatches to it. The
    profiles are built from ``ufl.sign``/``ufl.And`` and the ``conditional``,
    ``exp`` and ``sqrt`` names available in this file (presumably the UFL
    ones - confirm against ``FEnicS_base``).

    The sign of ``slope`` picks the direction of the switch: the helpers
    ``_singP``/``_singM`` evaluate to ``(1 + sign(slope)) / 2`` and
    ``(1 - sign(slope)) / 2`` respectively.
    """

    # Names of the profile methods that ``create`` may dispatch to.
    _KINDS = set()

    def __init__(self, x: 'SpatialCoordinate', x0, slope):
        """Store the coordinate ``x``, the switch position ``x0`` and ``slope``."""
        self.x = x
        self.x0 = x0
        self.slope = slope

    def create(self, kind: str):
        """Build the profile registered under ``kind``.

        Args:
            kind: Name of a registered profile method (see ``get_kinds``).

        Returns:
            Whatever the selected profile method returns.

        Raises:
            AssertionError: If ``kind`` is not a registered profile.
        """
        if kind not in Light_collection.get_kinds(False):
            # Raised explicitly so the check is not stripped by ``python -O``
            raise AssertionError('Not implemented method')
        return getattr(self, kind)()

    def _kind_ready(func=None, get=False, l=_KINDS):
        """Register ``func`` as an available kind, or return the registry.

        With ``get`` false (decorator use) the name of ``func`` is recorded
        and ``func`` is returned unchanged. With ``get=True`` the set of
        registered names is returned.
        """
        if get:
            return l
        l.add(func.__name__)
        return func

    @classmethod
    def get_kinds(cls, view=True):
        """Return the registered kind names, or print them.

        Args:
            view: If true, print one name per line (sorted) and return
                ``None``; otherwise return the set of names.
        """
        kinds = cls._kind_ready(get=True)
        if not view:
            return kinds
        print('\n'.join(sorted(kinds)))

    def _singP(self):
        # 1 for positive slope, 0 for negative slope
        return (1 + ufl.sign(self.slope)) / 2

    def _singM(self):
        # 0 for positive slope, 1 for negative slope
        return (1 - ufl.sign(self.slope)) / 2

    @_kind_ready
    def step(self):
        """Discontinuous switch at ``x0``."""
        return conditional(self.x0 <= self.x, self._singP(), self._singM())

    @_kind_ready
    def sigmoid(self):
        """Logistic profile ``1 / (1 + exp(-5 * slope * (x - x0)))``."""
        a = self.slope * 5
        return 1 / (1 + exp(-a * (self.x - self.x0)))

    @_kind_ready
    def trapsharp(self):
        """Linear ramp of total width ``1 / |slope|`` centred at ``x0``."""
        a = self.slope
        # Ramp part: non-zero only inside [x0 - 1/|2a|, x0 + 1/|2a|)
        res = conditional(
            ufl.And(
                self.x0 - 1 / (abs(2 * a)) <= self.x,
                self.x < self.x0 + 1 / (abs(2 * a))
            ),
            a * (self.x - self.x0) + 0.5,
            0,
        )
        # Plateau part: the threshold uses the signed 1/(2a), so it sits on
        # the opposite side of x0 for negative slopes
        res += conditional(
            self.x0 + 1 / (2*a) <= self.x,
            self._singP(),
            self._singM(),
        )
        return res

    @_kind_ready
    def parab(self):
        """Two parabolic arcs joined at ``x0``, each ``1/sqrt(|2a|)`` wide, ``a = 5 * slope``."""
        a = self.slope * 5
        # Arc to the left of x0
        res = conditional(
            ufl.And(
                self.x0 - 1 / sqrt(abs(2 * a)) <= self.x,
                self.x < self.x0,
            ),
            a * (self.x - self.x0 + 1 / sqrt(abs(2 * a)))**2 + self._singM(),
            0
        )
        # Arc to the right of x0
        res += conditional(
            ufl.And(
                self.x0 <= self.x,
                self.x < self.x0 + 1 / sqrt(abs(2 * a)),
            ),
            -a * (self.x - self.x0 - 1 / sqrt(abs(2 * a)))**2 + self._singP(),
            0
        )
        # Plateau part; sign(a) moves the threshold to the other side of x0
        # for negative slopes
        res += conditional(
            self.x0 + 1 / (ufl.sign(a) * sqrt(abs(2 * a))) <= self.x,
            self._singP(),
            self._singM()
        )
        return res

    # Wrapped only after the decorators above have run: a ``staticmethod``
    # object is not callable inside the class body before Python 3.10.
    _kind_ready = staticmethod(_kind_ready)


def create_equation(dt=0.01):
    """Assemble the time-discretised weak forms for the N and P fields.

    Reads the module-level ``DOMAIN``, ``SPACE``, ``INDIC``, ``INDIC_0``,
    ``SURFACE``, ``LIHGT`` and ``CONST``; ``CONST`` must provide the keys
    ``'A_NM'``, ``'B_PM'``, ``'E_NP'`` and ``'REACTION'``.

    Args:
        dt: Time-step size used in the ``(value - previous) / dt`` term.
            Defaults to 0.01.

    Returns:
        dict: ``{'N': form, 'P': form}``.
    """

    def create_facets():
        # set_connectivity comes from outside this file; presumably it
        # prepares the facet connectivity the boundary measure needs - confirm
        set_connectivity(DOMAIN)
        return Measure("ds", domain=DOMAIN)

    def inside_flux(
        a_NM=CONST['A_NM'],
        b_PM=CONST['B_PM'],
        e_NP=CONST['E_NP'],
    ):
        # Each flux is the plain sum of its five contributions
        flux_N = sum([
            -a_NM * grad(n),
            +a_NM * p * grad(n),
            -e_NP * p * grad(n),
            -a_NM * n * grad(p),
            +e_NP * n * grad(p),
        ])
        flux_P = sum([
            -b_PM * grad(p),
            +b_PM * n * grad(p),
            -e_NP * n * grad(p),
            -b_PM * p * grad(n),
            +e_NP * p * grad(n),
        ])
        return flux_N, flux_P

    dx = Measure('cell', subdomain_id='everywhere')
    ds = create_facets()

    u, v = TestFunctions(SPACE)
    n, p = INDIC['N'], INDIC['P']
    n0, p0 = INDIC_0['N'], INDIC_0['P']
    qN, qP = inside_flux()

    # N: time difference, flux tested against grad(u), boundary term
    equationN = (1/dt) * (n-n0) * u * dx
    equationN += -(qN|inner|grad(u)) * dx
    equationN += u * SURFACE['N'] * ds

    # P: same structure plus the light-driven reaction source; the flux
    # term uses the same contraction (inner) as the N equation
    equationP = (1/dt) * (p-p0) * v * dx
    equationP += -(qP|inner|grad(v)) * dx
    equationP += -LIHGT * CONST['REACTION'] * v * dx
    equationP += v * SURFACE['P'] * ds

    return {'N': equationN, 'P': equationP}


# 1D interval mesh on [0, 1] created with nx=100.
DOMAIN = mesh.create_interval(nx=100, comm=MPI.COMM_WORLD, points=[0, 1])
# Degree-2 'CG' elements for the two unknowns.
# NOTE(review): 'P' is a VectorElement while 'N' is a scalar FiniteElement,
# yet create_equation multiplies p and its test function like scalars -
# confirm the vector element is intended.
element = {
    'N':ufl.FiniteElement(
        family='CG',
        cell=DOMAIN.ufl_cell(),
        degree=2),
    'P':ufl.VectorElement(
        family='CG',
        cell=DOMAIN.ufl_cell(),
        degree=2)
    } # yapf: disable
# Mixed space holding both unknowns.
SPACE = FunctionSpace(
    mesh=DOMAIN,
    element=ufl.MixedElement(element['N'], element['P']),
)
# Collapsed sub-spaces, one per unknown ([0] takes the space from collapse()).
SUBSPACE = {
    'N': SPACE.sub(0).collapse()[0],
    'P': SPACE.sub(1).collapse()[0],
}
# First component of the spatial coordinate, used by the light profiles.
x = SpatialCoordinate(SPACE)[0]

# FUNC is the current mixed solution, FUNC0 the previous time step
# (solve() copies FUNC into FUNC0 before every step).
FUNC, FUNC0 = Function(SPACE), Function(SPACE)
# Sub-functions of FUNC: used for interpolation and plotting.
SUB_FUNC = {
    'N': FUNC.sub(0),
    'P': FUNC.sub(1),
}
# split() components of FUNC: used inside the weak forms.
INDIC = {
    'N': split(FUNC)[0],
    'P': split(FUNC)[1],
}
# split() components of FUNC0: previous-step values in the weak forms.
INDIC_0 = {
    'N': split(FUNC0)[0],
    'P': split(FUNC0)[1],
}

# Simulation time; reset to 0 in set_initial() and advanced in solve().
TIME = Constant(SUBSPACE['N'], 0)
# Light profile: product of a rising switch at x0=0.2 and a falling one at
# x0=0.5 (opposite slope signs).
# NOTE(review): 'close' is not one of the kinds registered on
# Light_collection in this file (step, sigmoid, trapsharp, parab), so
# create() rejects it - confirm the intended kind.
LIHGT = Light_collection(x=x, x0=0.2, slope=100).create(kind='close')
LIHGT *= Light_collection(x=x, x0=0.5, slope=-100).create(kind='close')

# Model constants.
# NOTE(review): create_equation reads CONST['A_NM'], CONST['B_PM'],
# CONST['E_NP'] and CONST['REACTION'], none of which are defined here -
# confirm where they are supposed to come from.
CONST= {
    'd': 0.01,
}

# Boundary terms integrated over ds in create_equation (zero here).
SURFACE = {
    'N': 0,
    'P': 0,
}
# Dirichlet boundary conditions; empty list here.
BCS = []

# Weak forms; built last because create_equation reads the globals above.
EQUATION = create_equation()
# FIXME: clear reset KSP solver
# Placeholder; the solver problem object is constructed in set_initial().
PROBLEM = []


def set_initial():
    """Reset fields and clock to the initial state and rebuild ``PROBLEM``.

    Sets the N component of ``FUNC`` to 0.1 and the P component to 0,
    copies ``FUNC`` into ``FUNC0``, resets ``TIME`` to 0 and stores a fresh
    ``NonlinearProblem`` in the module-level ``PROBLEM`` so that ``solve``
    can use it.
    """
    # Without this declaration the assignment below would only bind a local
    # name and the module-level PROBLEM would stay the empty placeholder.
    global PROBLEM

    # TODO: reset bcs
    SUB_FUNC['N'].interpolate(Function(SUBSPACE['N'], 0.1))
    SUB_FUNC['P'].interpolate(Function(SUBSPACE['P'], 0))

    FUNC.x.scatter_forward()
    # Previous-step state starts equal to the initial state
    FUNC0.interpolate(FUNC)

    TIME.value = 0

    PROBLEM = NonlinearProblem(
        F=sum(EQUATION.values()),
        bcs=[],
        u=FUNC,
        solve_options={
            'convergence': 'incremental', 'tolerance': 1E-6
        },
        petsc_options={
            'ksp_type': 'preonly',
            'pc_type': 'lu',
            'pc_factor_mat_solver_type': 'mumps'
        },
        form_compiler_params={},
        jit_params={},
    )


# Run the initial setup once when this cell is executed.
set_initial()

In [None]:
def solve(n_steps, reset=True, save=False):
    """Advance the problem by ``n_steps`` time steps with a progress bar.

    Args:
        n_steps: Number of steps to run.
        reset: If true, call ``set_initial()`` before stepping.
        save: Accepted but not used by this implementation.
    """
    if reset:
        set_initial()

    progress = tqdm(
        desc=f'Solving PDE. Time:{TIME.value:.3f}',
        iterable=np.arange(0, n_steps, dtype=int),
    )
    for _ in progress:
        # TODO: interpolate bcs
        # Shift the current solution into the previous-step slot and move
        # the clock forward before solving for the new step
        FUNC0.interpolate(FUNC)
        TIME.value += 0.01
        progress.set_description(f'Solving PDE. Time:{TIME.value:.2f}')
        PROBLEM.solve()




In [None]:
# solve() has no default for n_steps, so it must be passed explicitly.
# NOTE(review): 2000 is the value from the previously commented-out
# argument - adjust as needed.
solve(
    n_steps=2000,
    reset=True,
)
# NOTE(review): the second curve is assumed to be the P field; the same
# N function was listed twice before - confirm.
function_plots([
    Function(SUBSPACE['N'], SUB_FUNC['N']),
    Function(SUBSPACE['P'], SUB_FUNC['P']),
]);
