In [1]:
import numpy as np
import matplotlib.pyplot as plt
import os
import pickle
import sys
import time
import torch

from scipy.integrate import solve_ivp

# Append the parent of the current working directory to the import path.
# Presumably this is where the project-local modules imported below
# (sample_points, tcpinn, plot) live.
# NOTE(review): this depends on the kernel's working directory at execution
# time — confirm when the notebook is run from a different folder.
PROJECT_ROOT = os.path.abspath(
    os.path.join(os.getcwd(), os.pardir)
)
sys.path.append(PROJECT_ROOT)

import sample_points
from tcpinn import TcPINN
from plot import plot_solution

# Seed every random number generator used in this notebook so that a
# "Restart & Run All" reproduces the same results:
#   * NumPy's legacy global generator drives the np.random.uniform draws
#     used to sample the training points.
#   * torch's default generator is seeded as well, because the model trained
#     below is a torch network (the original cell seeded NumPy only).
SEED = 1
np.random.seed(SEED)
torch.manual_seed(SEED)

In this notebook, we consider the following nonlinear ODE system:

\begin{align}
    \frac{dx}{dt} &= 10x - 5xy, \\
    \frac{dy}{dt} &= 3y + xy - 3y^2.
\end{align}

We train a tcPINN to solve this system for $t \in [0, 1]$ and $x_0, y_0 \in [0, 5]$, and investigate its extrapolation behavior.

In [2]:
class ExampleNonlinear(TcPINN):
    """
    tcPINN for the two-dimensional nonlinear ODE system

        dx/dt = 10x - 5xy
        dy/dt = 3y + xy - 3y^2

    This subclass only supplies the ODE residual loss (`_loss_pinn`); the
    constructor forwards all of its arguments unchanged to `TcPINN`.
    """
    def __init__(
        self, layers, T, X_pinn=None, X_semigroup=None, X_smooth=None, X_data=None, data=None,
        w_pinn=1., w_semigroup=1., w_smooth=1., w_data=1.
    ):
        # Group the arguments for readability; they are still handed to the
        # base class positionally and in the original order.
        point_sets = (X_pinn, X_semigroup, X_smooth, X_data, data)
        loss_weights = (w_pinn, w_semigroup, w_smooth, w_data)
        super().__init__(layers, T, *point_sets, *loss_weights)

    def _loss_pinn(self):
        """
        Weighted mean-squared residual of the ODE system

            dx/dt = 10x - 5xy
            dy/dt = 3y + xy - 3y^2

        evaluated at the stored points (self.t_pinn, self.y_pinn).
        Returns self.w_pinn times the sum of the two per-equation mean
        squared residuals.
        """
        prediction = self.net_y(self.t_pinn, self.y_pinn)
        time_derivative = self.net_derivative(self.t_pinn, self.y_pinn)

        # Keep the trailing axis (0:1 / 1:2 slices) so each component stays a column.
        x = prediction[:, 0:1]
        y = prediction[:, 1:2]

        # time_derivative is indexed per ODE component: [0] pairs with x, [1] with y.
        residual_x = time_derivative[0] - 10 * x + 5 * x * y
        residual_y = time_derivative[1] - 3 * y - x * y + 3 * y**2

        mse_x = torch.mean(residual_x**2)
        mse_y = torch.mean(residual_y**2)

        return self.w_pinn * (mse_x + mse_y)

### Set up the training data

In [3]:
# Network layout: an input of size ode_dimension + 1 (matching the columns of
# X_pinn below: one time column plus ode_dimension state columns), six hidden
# layers of width 64, and an output of size ode_dimension.
ode_dimension = 2
layers = [ode_dimension + 1] + 6 * [64] + [ode_dimension]
# Upper bound of the time interval sampled for training.
T = 1
# Upper bound of the initial values sampled for training (per component).
max_y0 = 5

# Training samples. The order of the np.random calls below determines the
# drawn values for a given seed, so do not reorder these statements.
# NOTE(review): only 10 points per loss term looks like a smoke-test size —
# confirm before using this notebook for real experiments.

# Points for the ODE residual (PINN) loss: rows of [t, y_1, y_2].
n_pinn = 10
t_pinn = np.random.uniform(0, T, (n_pinn, 1))
y_pinn = np.random.uniform(0, max_y0, (n_pinn, ode_dimension))
X_pinn = np.hstack([t_pinn, y_pinn])

# Points for the semigroup loss: time samples from
# sample_points.uniform_triangle_2d (presumably pairs (s, t) — TODO confirm)
# stacked with sampled initial values.
n_semigroup = 10
st_semigroup = sample_points.uniform_triangle_2d(n_semigroup, T)
y_semigroup = np.random.uniform(0, max_y0, (n_semigroup, ode_dimension))
X_semigroup = np.hstack([st_semigroup, y_semigroup])

# Points for the smoothness loss: rows of [t, y_1, y_2].
n_smooth = 10
t_smooth = np.random.uniform(0, T, (n_smooth, 1))
y_smooth = np.random.uniform(0, max_y0, (n_smooth, ode_dimension))
X_smooth = np.hstack([t_smooth, y_smooth])

In [4]:
# Build the model from the three sampled point sets. X_data / data keep their
# None defaults and all four loss weights keep their default of 1.
model = ExampleNonlinear(layers, T, X_pinn=X_pinn, X_semigroup=X_semigroup, X_smooth=X_smooth)

In [5]:
%%time               
# Run the training routine (not defined in this notebook; presumably inherited
# from TcPINN).
# NOTE(review): the recorded output ends in KeyboardInterrupt, i.e. this run
# was stopped manually, so the model used below is only partially trained.
model.train()

iteration 100; loss: 0.8644, loss_pinn: 0.5062, loss_semigroup: 0.0583, loss_smooth: 0.3000
iteration 200; loss: 0.0686, loss_pinn: 0.0411, loss_semigroup: 0.0180, loss_smooth: 0.0095


KeyboardInterrupt: 

In [6]:
path = os.getcwd()
model_file = f"{path}/model_example_nonlinear.pkl"

# Persist the current model next to the notebook ...
with open(model_file, "wb") as out_stream:
    pickle.dump(model, out_stream, protocol=pickle.HIGHEST_PROTOCOL)

# ... and read it straight back, which checks the round trip.
# pickle.load can execute arbitrary code: only load files written by this notebook.
with open(model_file, "rb") as in_stream:
    model = pickle.load(in_stream)

## Predict and Plot the Solution

In [7]:
def rhs_example_nonlinear(t, r):
    """
    Right-hand side of the nonlinear ODE system

        dx/dt = 10x - 5xy
        dy/dt = 3y + xy - 3y^2

    `t` is accepted but not used in the computation. `r` must unpack into
    exactly two components (x, y). Returns the tuple (dx/dt, dy/dt).
    """
    x_now, y_now = r
    return (
        10 * x_now - 5 * x_now * y_now,
        3 * y_now + x_now * y_now - 3 * y_now**2,
    )


def get_solution(max_t, delta_t, init_val):
    """
    Numerically integrate the nonlinear ODE system with scipy's solve_ivp.

    Parameters
    ----------
    max_t : end of the integration interval [0, max_t].
    delta_t : spacing used to size the evaluation grid; the grid has
        int(max_t / delta_t) + 1 equally spaced points from 0 to max_t.
    init_val : array-like initial state at t = 0.

    Returns
    -------
    Array with one row per evaluation time and one column per state
    component (the transpose of solve_ivp's `y`).
    """
    times = np.linspace(0, max_t, int(max_t / delta_t) + 1)
    # Integrate from the initial state passed in by the caller. The original
    # code passed the notebook-global `y0` here and silently ignored `init_val`.
    sol = solve_ivp(
        rhs_example_nonlinear, [0, float(max_t)], init_val, t_eval=times,
        rtol=1e-10, atol=1e-10
    )
    return sol.y.T

In [8]:
# Fixed initial condition and a time horizon of max_t = 5, longer than the
# training interval T = 1.
y0 = np.array([2.0, 5.0])
max_t = 5
delta_t = 0.01
# Evaluation grid: int(max_t / delta_t) + 1 equally spaced times on [0, max_t].
times = np.linspace(0, max_t, int(max_t / delta_t) + 1)

# Reference trajectory from the numerical ODE solver and the model's
# prediction for the same initial value and horizon.
# NOTE(review): these results are overwritten by the next cell without being
# plotted — confirm whether this cell is still needed.
true_solution = get_solution(max_t, delta_t, y0)
tc_solution = model.predict_tc(max_t, delta_t, y0)

In [11]:
# Extrapolation check: training used T = 1, whereas the trajectory here is
# predicted up to max_t = 5.
# Random initial value with both components drawn from [0, max_y0].
y0 = np.random.uniform(0, max_y0, 2)
max_t = 5
delta_t = 0.01
# Evaluation grid: int(max_t / delta_t) + 1 equally spaced times on [0, max_t].
times = np.linspace(0, max_t, int(max_t / delta_t) + 1)

# Reference trajectory from the numerical ODE solver vs. the model's prediction.
true_solution = get_solution(max_t, delta_t, y0)
tc_solution = model.predict_tc(max_t, delta_t, y0)

# Draw both trajectories on the same axes: reference in black, tcPINN in blue.
# Only the first component of each carries a label, so the legend shows one
# entry per method.
ax = plot_solution(
    times, true_solution, 
    component_kwargs=[{'color': "black", 'label': "truth"}, {'color': "black"}]
)
ax = plot_solution(
    times, tc_solution, ax=ax,
    component_kwargs=[{'color': "blue", 'label': "tcPINN"}, {'color': "blue"}]
)
plt.legend()
plt.show()
plt.close()