In [None]:
# Setup: FEniCS (DOLFIN/UFL) for the finite-element solve, mshr for mesh
# generation, numpy and matplotlib.
# NOTE(review): the star imports hide where names such as Point, Rectangle,
# generate_mesh, plot and split come from; the import order is left as-is
# because a later star import could shadow names from an earlier one.
from fenics import *
# Uniform sampler on numpy's global RNG; used by the random initial condition.
from numpy.random import random
# Quiet FEniCS logging. 30 presumably corresponds to DOLFIN's LogLevel.WARNING
# (hiding INFO/PROGRESS messages) — confirm against the installed version.
set_log_level(30)
import matplotlib.pyplot as plt
from mshr import *
import numpy as np

In [None]:
class PredatorPrey(NonlinearProblem):
    """Nonlinear problem handed to the Newton solver.

    a -- bilinear form assembled as the Jacobian
    L -- linear form assembled as the residual
    """

    def __init__(self, a, L):
        NonlinearProblem.__init__(self)
        self.a = a
        self.L = L

    def F(self, b, x):
        """Assemble the residual form ``self.L`` into the vector ``b``.

        ``x`` (the current iterate) is not read; this assumes the form
        references the Function whose vector is being solved for.
        """
        assemble(self.L, tensor=b)

    def J(self, A, x):
        """Assemble the Jacobian form ``self.a`` into the matrix ``A``.

        ``x`` is not read, for the same reason as in ``F``.
        """
        assemble(self.a, tensor=A)

In [None]:
# Computational domain: a DOMAIN_SIZE x DOMAIN_SIZE square meshed with mshr.
# (An earlier header called this an "irregular domain", but only a rectangle
# is built; for a non-square geometry, combine further mshr shapes into
# `domain` with CSG operators.)
DOMAIN_SIZE = 50.0     # side length of the square
MESH_RESOLUTION = 150  # resolution argument passed to mshr.generate_mesh

# Descriptive corner names: short names like `p1` are easily reused (and
# silently clobbered) elsewhere in a notebook.
corner_lo = Point(0.0, 0.0)
corner_hi = Point(DOMAIN_SIZE, DOMAIN_SIZE)

square = Rectangle(corner_lo, corner_hi)
domain = square

mesh = generate_mesh(domain, MESH_RESOLUTION)
plot(mesh)
plt.title("Mesh")
plt.show()

In [None]:
# Mixed space built from two continuous piecewise-quadratic Lagrange
# elements: component 0 is used for N and component 1 for P.
U = FiniteElement("CG", mesh.ufl_cell(), 2)
W = FunctionSpace(mesh, MixedElement([U, U]))

# Unknown at the new time level (w) and solution at the previous level (w0).
w = Function(W)
w0 = Function(W)

# Trial function (Newton direction) and test functions: q pairs with N,
# p pairs with P.
du = TrialFunction(W)
q, p = TestFunctions(W)

In [None]:
# Time discretization.
dt = 5     # time-step size
T = 1000   # final time

# UFL component views of the mixed-space functions:
#   dN, dP -- components of the trial function du
#   N,  P  -- components of the unknown w (new time level)
#   N0, P0 -- components of w0 (previous time level)
dN, dP = split(du)
N, P = split(w)
N0, P0 = split(w0)

In [None]:
class IC(UserExpression):
    """Random initial condition for the two-component (N, P) field.

    Every evaluation draws independent values for both components from
    1.0 * U[0, 1) + 0.25, using numpy's global random generator.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def eval(self, values, x):
        # The point x is ignored: the field is spatially uncorrelated noise.
        # Component 0 is drawn before component 1, as in the original.
        for component in range(2):
            values[component] = 1.0*random() + 0.25

    def value_shape(self):
        """Two scalar components per point."""
        return (2,)

In [None]:
# Fix numpy's global RNG (IC.eval draws from it) so the random initial
# condition, and hence every saved figure, is reproducible.
RANDOM_SEED = 0
np.random.seed(RANDOM_SEED)

w_init = IC(element=W.ufl_element(), degree=2)

# Interpolate the random field once and copy it into w0. Interpolating w_init
# separately into w and w0 draws two *different* random fields. The time loop
# copies w into w0 before its first solve, so the t=0 plots below (built from
# w0) would otherwise not show the state the simulation actually starts from.
w.interpolate(w_init)
w0.vector()[:] = w.vector()

im_P0 = plot(P0)
im_P0.set_cmap("seismic")
plt.title("$P(t=0)$")
plt.colorbar(im_P0)
plt.savefig("predator_prey_P_0.png")
plt.show()
plt.close()

im_N0 = plot(N0)
im_N0.set_cmap("gray")
plt.title("$N(t=0)$")
plt.colorbar(im_N0)
plt.savefig("predator_prey_N_0.png")
plt.show()
plt.close()

In [None]:
# Model parameters (each named by the term it multiplies below).
D_N = 0.015   # diffusion coefficient of N
D_P = 1.0     # diffusion coefficient of P
alpha = 0.65  # coefficient of the alpha*N*(1-N)*(N+P) source in the N equation
beta = 0.25   # coefficient of the -beta*P*(N+P) loss in the P equation
gamma = 0.5   # coefficient of the +gamma*N*P gain in the P equation

# Backward-Euler residuals, multiplied through by dt. All reaction and
# diffusion terms use the new-time unknowns N, P. There are no boundary
# integrals, so the boundary condition is natural (zero flux). The implied
# strong form is
#   dN/dt = D_N*Lap(N) + alpha*N*(1-N)*(N+P) - N*P
#   dP/dt = D_P*Lap(P) - beta*P*(N+P)        + gamma*N*P
# Each term below keeps the exact operand grouping of the original
# expression, so the resulting UFL forms are unchanged.
prey_time = N*q - N0*q
prey_diffusion = D_N*inner(grad(N), grad(q))*dt
prey_source = alpha*N*(1-N)*(N+P)*q*dt
prey_loss = N*P*q*dt
L0 = prey_time + prey_diffusion - prey_source + prey_loss

pred_time = P*p - P0*p
pred_diffusion = D_P*inner(grad(P), grad(p))*dt
pred_loss = beta*P*(N+P)*p*dt
pred_gain = gamma*N*P*p*dt
L1 = pred_time + pred_diffusion + pred_loss - pred_gain

L = (L0 + L1)*dx

In [None]:
# Jacobian of the residual L: its derivative with respect to the unknown
# mixed function w, taken in the direction of the trial function du.
a = derivative(L, w, du)

problem = PredatorPrey(a, L)

# Newton's method with a direct (LU) linear solve and an increment-based
# convergence test. The previous relative tolerance of 1e-1 let Newton stop
# while its updates were still large (relative to the first update, in
# DOLFIN's incremental criterion — confirm for your version). Quadratic
# convergence makes a tight tolerance cost only a few extra iterations.
NEWTON_RTOL = 1e-6

solver = NewtonSolver()
solver.parameters["linear_solver"] = "lu"
solver.parameters["convergence_criterion"] = "incremental"
solver.parameters["relative_tolerance"] = NEWTON_RTOL

In [None]:
def plot_field(field, cmap, title, filename):
    """Plot a scalar field with a colorbar, save it, display it, then close it.

    field    -- scalar function passed to FEniCS ``plot``
    cmap     -- matplotlib colormap name
    title    -- figure title (may contain $...$ math)
    filename -- path passed to ``plt.savefig``
    """
    mappable = plot(field)
    mappable.set_cmap(cmap)
    plt.title(title)
    plt.colorbar(mappable)
    plt.savefig(filename)
    plt.show()
    plt.close()  # avoid accumulating open figures over many time steps


PLOT_EVERY = 1  # print/plot/save every k-th step; 1 reports every step

# March from t=0 until t reaches T, one backward-Euler step per iteration.
t = 0
step = 0
while t < T:
    step += 1
    # Multiply rather than accumulate, so non-integer dt does not drift.
    t = step * dt
    # The current solution becomes the previous-time state; w is also
    # Newton's initial guess for the new step.
    w0.vector()[:] = w.vector()
    solver.solve(problem, w.vector())

    if step % PLOT_EVERY != 0:
        continue

    # Use distinct names so the UFL components N, P and the test function p
    # used in the forms are not overwritten. Otherwise, re-running the form or
    # solver cells after this loop would silently build wrong forms or fail.
    N_h, P_h = w.split()
    t_label = round(t, 2)
    print("t=", t)
    plot_field(P_h, "seismic", "$P(t={})$".format(t_label),
               "predator_prey_P_{}.png".format(t_label))
    plot_field(N_h, "gray", "$N(t={})$".format(t_label),
               "predator_prey_N_{}.png".format(t_label))