In [51]:
import numpy as np
import ufl
from mpi4py import MPI
from petsc4py import PETSc

import dolfinx.fem.petsc as petsc
from dolfinx import fem, io, mesh

In [52]:
# --- Time discretisation ----------------------------------------------------
T = 1.0                      # final simulation time
num_steps = 50               # number of splitting steps
dt = T / num_steps           # step size

# --- Model / scheme parameters ----------------------------------------------
lam = 1                      # enters the diffusion coefficient lam/(1+lam); presumably the
                             # extra/intra-cellular conductivity ratio -- TODO confirm
theta = 1/2                  # theta-rule weight: each step is split into theta*dt and
                             # (1-theta)*dt ODE sub-steps and weights implicit/explicit
                             # diffusion; 1/2 gives Crank-Nicolson diffusion + symmetric splitting
gamma = dt * lam / (1+lam)   # dt times the diffusion coefficient, used in the weak form

In [53]:
# Structured quadrilateral mesh of the unit square with bilinear (Q1) Lagrange elements
Nx = Ny = 50  # cells per direction

domain = mesh.create_unit_square(
    MPI.COMM_WORLD, Nx, Ny, mesh.CellType.quadrilateral
)
V = fem.functionspace(domain, ("Lagrange", 1))

# Simulation time lives in a Constant, so every form that references it sees
# updates to t.value without being recompiled
t = fem.Constant(domain, 0.0)

# Manufactured stimulus: for v = sin(t) cos(2*pi*x) cos(2*pi*y) the Laplacian of the
# spatial mode is -8*pi^2 times itself, so this term cancels lam/(1+lam) * Laplace(v)
X = ufl.SpatialCoordinate(domain)
stim_amplitude = 8 * ufl.pi**2 * lam/(1+lam)
I_stim = stim_amplitude * ufl.sin(t) * ufl.cos(2*ufl.pi*X[0]) * ufl.cos(2*ufl.pi*X[1])

# Initial conditions
def initial_v(x):
    """Initial value of v: zero at every evaluation point.

    ``x`` is the coordinate array passed by ``Function.interpolate``; row 0 holds
    the x-coordinates, so the result has one entry per point.
    """
    return np.zeros_like(x[0])

# Solution v at the current time level, starting from initial_v
vn = fem.Function(V, name="vn")
vn.interpolate(initial_v)

def initial_s(x):
    """Initial value of s: -cos(2*pi*x) * cos(2*pi*y).

    ``x`` is the coordinate array passed by ``Function.interpolate``; rows 0 and 1
    hold the x- and y-coordinates.
    """
    cos_x = np.cos(2*np.pi * x[0])
    cos_y = np.cos(2*np.pi * x[1])
    return -cos_x * cos_y

# ODE state s at the current time level, starting from initial_s
sn = fem.Function(V, name="sn")
sn.interpolate(initial_s)

# Fields after the first (theta*dt) ODE sub-step of each time step
vntheta = fem.Function(V, name="vntheta")
sntheta = fem.Function(V, name="sntheta")

# XDMF output: write the mesh once, then the initial v at the current time
xdmf = io.XDMFFile(domain.comm, "monodomain.xdmf", "w")
xdmf.write_mesh(domain)
xdmf.write_function(vn, float(t.value))

In [54]:
v = ufl.TrialFunction(V)
phi = ufl.TestFunction(V)

# theta-rule discretisation of the diffusion sub-problem over one full step dt,
#   v - vntheta = dt*lam/(1+lam) * Laplace(theta*v + (1-theta)*vntheta) + dt*I_stim,
# tested with phi and integrated by parts (gamma = dt*lam/(1+lam)). No boundary
# integral appears, i.e. a natural zero-flux boundary condition.

# Implicit part (unknown v) -> bilinear form
mass_lhs = phi * v * ufl.dx
diffusion_lhs = gamma * theta * ufl.dot(ufl.grad(phi), ufl.grad(v)) * ufl.dx
a = mass_lhs + diffusion_lhs

# Explicit part (known vntheta and stimulus) -> linear form
mass_rhs = phi * (vntheta + dt * I_stim) * ufl.dx
diffusion_rhs = gamma * (1-theta) * ufl.dot(ufl.grad(phi), ufl.grad(vntheta)) * ufl.dx
L = mass_rhs - diffusion_rhs

In [55]:
# Compile both forms once; time-dependent data (t, vntheta) enter as
# constants/coefficients, so only the right-hand side is re-assembled per step
compiled_a = fem.form(a)
compiled_L = fem.form(L)

# The bilinear form has time-independent coefficients, so its matrix is assembled once
A = petsc.assemble_matrix(compiled_a)
A.assemble()

# Work Function whose underlying PETSc vector receives the assembled right-hand side
b = fem.Function(V)

In [56]:
def forward_euler_step(init, dt):
    """Advance the linear ODE system v' = -s, s' = v by one explicit Euler step.

    Parameters
    ----------
    init : sequence
        ``init[0]`` is v and ``init[1]`` is s at the start of the step
        (NumPy arrays or scalars).
    dt : float
        Step size.

    Returns
    -------
    tuple
        ``(v, s)`` after the step; both updates use only the old values.
    """
    v_old, s_old = init[0], init[1]
    return v_old - dt * s_old, s_old + dt * v_old

In [57]:
from petsc4py import PETSc
# Direct solve: PREONLY applies the LU preconditioner exactly once per solve.
# A is not modified in the time loop, so PETSc can presumably reuse the
# factorisation across steps -- confirm with -ksp_view if it matters.
solver = PETSc.KSP().create(domain.comm)
solver.setOperators(A)
solver.setType(PETSc.KSP.Type.PREONLY)
lu_pc = solver.getPC()
lu_pc.setType(PETSc.PC.Type.LU)

In [58]:
def v_exact(t):
    """Return a callable ``x -> sin(t) * cos(2*pi*x) * cos(2*pi*y)``.

    The callable matches the signature expected by ``Function.interpolate``:
    rows 0 and 1 of ``x`` hold the x- and y-coordinates.
    """
    def evaluate(x):
        spatial = np.cos(2 * np.pi * x[0]) * np.cos(2 * np.pi * x[1])
        return spatial * np.sin(t)
    return evaluate
# Exact solution at t = 0, written next to the numerical one for comparison
v_ex = fem.Function(V)
v_ex.name = "v_ex"
v_ex.interpolate(v_exact(0))
xdmf.write_function(v_ex, 0.0)

# Each step: ODE over theta*dt, diffusion PDE over dt, ODE over (1-theta)*dt.
# Loop a fixed number of steps and compute times from the step index: testing
# `t.value < T` on an accumulated float can take one extra step when round-off
# leaves t just below T (e.g. num_steps=5, theta=1/2 sums 0.1 ten times to
# 0.9999999999999999). Assumes the simulation starts at t = 0, like the write above.
try:
    for n in range(num_steps):
        t_n = n * dt

        # Step 1: ODE sub-step; forms referencing t (I_stim) now see t_n + theta*dt
        t.value = t_n + theta * dt
        vntheta.x.array[:], sntheta.x.array[:] = forward_euler_step(
            [vn.x.array, sn.x.array], theta * dt
        )

        # Step 2: diffusion solve over the full step, starting from vntheta
        b.x.array[:] = 0
        petsc.assemble_vector(b.vector, compiled_L)
        # Add contributions assembled into ghost entries to their owning rank;
        # without this the RHS is wrong in parallel (no-op in serial)
        b.vector.ghostUpdate(addv=PETSc.InsertMode.ADD, mode=PETSc.ScatterMode.REVERSE)
        solver.solve(b.vector, vn.vector)
        vn.x.scatter_forward()

        # Step 3: remaining ODE sub-step. forward_euler_step builds new arrays before
        # assignment, so reading vn.x.array while overwriting it is safe.
        t.value = (n + 1) * dt
        vn.x.array[:], sn.x.array[:] = forward_euler_step(
            [vn.x.array, sntheta.x.array], (1 - theta) * dt
        )

        t_now = float(t.value)
        v_ex.interpolate(v_exact(t_now))
        xdmf.write_function(vn, t_now)
        xdmf.write_function(v_ex, t_now)
finally:
    # Close the file even if the loop fails, so the written steps stay readable
    xdmf.close()

In [59]:
# L2 error at the final time against the exact solution interpolated into V
# (reuses v_exact defined before the time loop rather than redefining it)
v_ex = fem.Function(V)
v_ex.interpolate(v_exact(float(t.value)))

comm = vn.function_space.mesh.comm
error = fem.form((vn - v_ex)**2 * ufl.dx)
# assemble_scalar returns this rank's contribution only; sum over all ranks
E = np.sqrt(comm.allreduce(fem.assemble_scalar(error), MPI.SUM))
if comm.rank == 0:
    print(f"L2-error: {E:.2e}")

L2-error: 6.40e-04
