In [None]:
import numpy as np
import os
import sys
from typing import Union, Callable

import matplotlib.pyplot as plt
import matplotlib as mpl

In [None]:
import sys
sys.path.append('..')
import helpers

In [None]:
import epde

import shutil

mpl.rcParams.update(mpl.rcParamsDefault)
# text.usetex hands all figure text to an external LaTeX toolchain
# (latex + dvipng for raster/inline output). If those tools are missing,
# rendering any figure raises, so only enable it when both are on PATH;
# otherwise Matplotlib's built-in mathtext renderer is used.
plt.rcParams['text.usetex'] = all(shutil.which(tool) is not None for tool in ('latex', 'dvipng'))

SMALL_SIZE = 12
mpl.rc('font', size=SMALL_SIZE)
mpl.rc('axes', titlesize=SMALL_SIZE)

In [None]:
def Lotka_Volterra_by_RK(initial : tuple, timestep : float, steps : int, alpha : float, 
                         beta : float, delta : float, gamma : float):
    """Integrate the Lotka-Volterra predator-prey system with classical RK4.

    The system is
        dx/dt = alpha * x - beta * x * y    (prey)
        dy/dt = delta * x * y - gamma * y   (predators)

    Args:
        initial: (x0, y0) initial prey and predator populations.
        timestep: integration step h.
        steps: number of time points returned, including the initial one.
        alpha, beta, delta, gamma: model coefficients.

    Returns:
        np.ndarray of shape (steps, 2), dtype float64; column 0 holds the
        prey series and column 1 the predator series.
    """
    def rhs(x, y):
        # Right-hand side of the ODE system: (dx/dt, dy/dt).
        return alpha * x - beta * x * y, delta * x * y - gamma * y

    res = np.full(shape = (steps, 2), fill_value = initial, dtype=np.float64)
    half = timestep / 2.
    for n in range(steps - 1):
        x, y = res[n]
        k1, l1 = rhs(x, y)
        k2, l2 = rhs(x + half * k1, y + half * l1)
        k3, l3 = rhs(x + half * k2, y + half * l2)
        # Classical RK4: the last stage is evaluated a full step ahead
        # along the third slope.
        k4, l4 = rhs(x + timestep * k3, y + timestep * l3)

        res[n + 1, 0] = x + timestep / 6. * (k1 + 2 * k2 + 2 * k3 + k4)
        res[n + 1, 1] = y + timestep / 6. * (l1 + 2 * l2 + 2 * l3 + l4)
    return res

steps_num = 301; step = 1./steps_num
# Build the time grid from integer indices so that len(t) == steps_num is
# guaranteed; np.arange with a floating-point step can return one extra
# point because of rounding in (stop - start) / step.
t = step * np.arange(steps_num)
solution = Lotka_Volterra_by_RK(initial=(4., 2.), timestep=step, steps=steps_num, 
                                alpha=20., beta=20., delta=20., gamma=20.)
assert len(t) == len(solution), "time grid and simulated series must align"

In [None]:
def epde_discovery(t, x, y, use_ann = False, popsize = 20, training_epochs = 30):
    """Fit a Pareto set of systems of differential equations to ``x`` and ``y``.

    Args:
        t: time coordinates shared by ``x`` and ``y`` (a 1-D array in this
            notebook); passed to EPDE as the only coordinate tensor.
        x, y: observed series, passed to EPDE as variables ``u`` and ``v``.
        use_ann: if True, use EPDE's 'ANN' preprocessor (``epochs_max`` =
            25000); otherwise use the 'poly' preprocessor with smoothing.
        popsize: population size for the MOEADD multi-objective search.
        training_epochs: number of MOEADD training epochs.

    Returns:
        The fitted ``epde.EpdeSearch`` object.
    """
    dimensionality = x.ndim - 1

    # NOTE(review): boundary=25 presumably excludes that many points at each
    # edge of the domain from the fit — confirm against the EPDE docs.
    epde_search_obj = epde.EpdeSearch(use_solver = False, dimensionality = dimensionality, 
                                      boundary = 25, coordinate_tensors = [t,])

    if use_ann:
        epde_search_obj.set_preprocessor(default_preprocessor_type='ANN',
                                         preprocessor_kwargs={'epochs_max' : 25000})
    else:
        epde_search_obj.set_preprocessor(default_preprocessor_type='poly',
                                         preprocessor_kwargs={'use_smoothing' : True, 'sigma' : 1, 
                                                              'polynomial_window' : 3, 'poly_order' : 3})

    epde_search_obj.set_moeadd_params(population_size = popsize, training_epochs = training_epochs)

    # Terms are built from one or two factors, each with probability 0.5.
    factors_max_number = {'factors_num' : [1, 2], 'probas' : [0.5, 0.5]}

    epde_search_obj.fit(data=[x, y], variable_names=['u', 'v'], max_deriv_order=(1,),
                        equation_terms_max_number=5, data_fun_pow = 2,
                        equation_factors_max_number=factors_max_number,
                        eq_sparsity_interval=(1e-8, 1e-4))
    return epde_search_obj

In [None]:
fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(10, 5))

# Left panel: both populations as functions of time.
ax1.plot(t, solution[:, 0], color='k', label='Prey')
ax1.plot(t, solution[:, 1], color='r', label='Hunters')
ax1.set(xlabel='Time', ylabel='Population size')
ax1.legend(loc='upper right')
ax1.grid()

# Right panel: phase portrait, prey on the horizontal axis against hunters.
ax2.plot(solution[:, 0], solution[:, 1], color='k')
ax2.set(xlabel='Prey', ylabel='Hunters')
ax2.grid()

plt.show()

In [None]:
# t_max is a sample count (slice index), not a time value: the first t_max
# points are used for discovery, the remainder goes to the *_test arrays.
t_max = 150
t_train, t_test = t[:t_max], t[t_max:]

x, x_test = solution[:t_max, 0], solution[t_max:, 0]
y, y_test = solution[:t_max, 1], solution[t_max:, 1]

epde_search_obj = epde_discovery(t_train, x, y, use_ann = False)
epde_search_obj.equations(only_print = True, num = 1)

In [None]:
# Select the first discovered system at the requested complexity.
# NOTE(review): the two 2.5 values are presumably one target complexity per
# equation (u and v) — confirm against the EPDE docs. `[0]` fails if no
# system matches that complexity.
disc_eq = epde_search_obj.get_equations_by_complexity([2.5, 2.5])[0]

In [None]:
# Predict trajectories of the selected system with EpdeSearch.predict in
# 'mat' mode, without explicitly supplied boundary conditions.
# NOTE(review): the plotting cell indexes the result as [:, 0] (u) and
# [:, 1] (v) over t_train — presumably an (n_points, 2) array; verify.
pred_u_v = epde_search_obj.predict(system=disc_eq, boundary_conditions=None, mode='mat')

In [None]:
# Compare the training data (generated by Lotka_Volterra_by_RK, not odeint)
# with the trajectories predicted from the discovered system.
fig, ax = plt.subplots()
ax.plot(t_train, x, '+', label = 'preys, RK4')
ax.plot(t_train, y, '*', label = 'predators, RK4')
ax.plot(t_train, pred_u_v[:, 0], color = 'b', label='preys, EPDE')
ax.plot(t_train, pred_u_v[:, 1], color = 'r', label='predators, EPDE')
ax.set_xlabel('Time t, [days]')
ax.set_ylabel('Population')
ax.grid()
ax.legend(loc='upper right')
plt.show()