In [1]:
import torch
from collections import OrderedDict

import numpy as np
import matplotlib
import matplotlib.pyplot as plt
# matplotlib.rcParams['pgf.texsystem'] = 'pdflatex'
# matplotlib.rcParams.update({'font.family': 'serif', 'font.size': 10})
# matplotlib.rcParams['text.usetex'] = True
from matplotlib.lines import Line2D
import pickle

from scipy.interpolate import griddata
import time

# Seed every random number generator used in this notebook (NumPy for the training
# samples, PyTorch for the network weight initialisation) so a fresh run is reproducible.
SEED = 1234
torch.manual_seed(SEED)
np.random.seed(SEED)

## Overview

This Jupyter notebook applies the plain time-consistent physics-informed neural network (tcPINN) approach to the planar three-body problem. We observed that the tcPINN is unable to fully learn the dynamics during training: the PINN loss does not converge to zero.

We hypothesize that this phenomenon is caused not by the chaotic nature of the ODE system, but by the velocities of the bodies diverging to infinity during close encounters. This causes the loss function to have singularities (e.g. when $r_1(t) = r_2(t)$), which makes training significantly more difficult. The default single-precision floating-point format used by PyTorch might be insufficient when the denominators of the summands in the PINN loss become very small. Therefore, using double or even higher precision might still allow our proposed approach to learn the dynamics of the three-body problem.

Instead of focusing on floating-point precision during training, we tried to simplify the training task by exploiting the symmetry and scale invariance of the ODE system. This allowed us to shrink the domain of the initial values. When the barycenter of the three bodies is constant in time, the size of the ODE system can also be reduced from twelve to eight equations. Further details on these ideas are given in the notebooks $\texttt{three_body_problem_center_barycenter}$ and $\texttt{three_body_problem_center_barycenter_restricted_ivp.ipynb}$.

Consider three gravitationally interacting identical bodies with positions $r_i(t) \in \mathbb{R}^2$. Assuming unit masses and a gravitational constant $G=1$, Newton's equations of motion read

\begin{align*}
    \frac{d^2}{dt^2} \begin{pmatrix} r_1(t) \\ r_2(t) \\ r_3(t) \end{pmatrix} = \begin{pmatrix} - \frac{r_1(t) - r_2(t)}{|r_1(t) - r_2(t)|^3} - \frac{r_1(t) - r_3(t)}{|r_1(t) - r_3(t)|^3} \\ - \frac{r_2(t) - r_1(t)}{|r_2(t) - r_1(t)|^3} - \frac{r_2(t) - r_3(t)}{|r_2(t) - r_3(t)|^3} \\ - \frac{r_3(t) - r_1(t)}{|r_3(t) - r_1(t)|^3} - \frac{r_3(t) - r_2(t)}{|r_3(t) - r_2(t)|^3} \end{pmatrix}.
\end{align*}

This is a second-order ODE system of $6$ equations. By introducing the velocities $v_i(t) = \frac{d}{dt}r_i(t)$, it can be rewritten as the following first-order ODE system of $12$ equations:

\begin{align*}
    \frac{d}{dt} \begin{pmatrix} r_1 \\ r_2 \\ r_3 \\ v_1 \\ v_2 \\ v_3 \end{pmatrix} = \begin{pmatrix} v_1 \\ v_2 \\ v_3 \\ - \frac{r_1 - r_2}{|r_1 - r_2|^3} - \frac{r_1 - r_3}{|r_1 - r_3|^3} \\ - \frac{r_2 - r_1}{|r_2 - r_1|^3} - \frac{r_2 - r_3}{|r_2 - r_3|^3} \\ - \frac{r_3 - r_1}{|r_3 - r_1|^3} - \frac{r_3 - r_2}{|r_3 - r_2|^3} \end{pmatrix}.
\end{align*}

For completeness, all $12$ equations with written-out components are

\begin{align*}
    \frac{d}{dt} \begin{pmatrix} r_{11} \\ r_{12} \\ r_{21} \\ r_{22} \\ r_{31} \\ r_{32} \\ v_{11} \\ v_{12} \\ v_{21} \\ v_{22} \\ v_{31} \\ v_{32} \end{pmatrix} = \begin{pmatrix} 
    v_{11} \\ v_{12} \\ v_{21} \\ v_{22} \\ v_{31} \\ v_{32} \\ 
    - \frac{r_{11} - r_{21}}{|r_1 - r_2|^3} - \frac{r_{11} - r_{31}}{|r_1 - r_3|^3} \\
    - \frac{r_{12} - r_{22}}{|r_1 - r_2|^3} - \frac{r_{12} - r_{32}}{|r_1 - r_3|^3} \\
    - \frac{r_{21} - r_{11}}{|r_2 - r_1|^3} - \frac{r_{21} - r_{31}}{|r_2 - r_3|^3} \\
    - \frac{r_{22} - r_{12}}{|r_2 - r_1|^3} - \frac{r_{22} - r_{32}}{|r_2 - r_3|^3} \\ 
    - \frac{r_{31} - r_{11}}{|r_3 - r_1|^3} - \frac{r_{31} - r_{21}}{|r_3 - r_2|^3} \\
    - \frac{r_{32} - r_{12}}{|r_3 - r_1|^3} - \frac{r_{32} - r_{22}}{|r_3 - r_2|^3}
\end{pmatrix}.
\end{align*}

With the notation $y = (r_{11}, r_{12}, r_{21}, r_{22}, r_{31}, r_{32}, v_{11}, v_{12}, v_{21}, v_{22}, v_{31}, v_{32})$ used in the implementation (components indexed from $0$, as in the code), the system reads

\begin{align*}
    \frac{d}{dt} y = \begin{pmatrix} 
    y_6 \\ y_7 \\ y_8 \\ y_9 \\ y_{10} \\ y_{11} \\ 
    - \frac{y_0 - y_2}{((y_0 - y_2)^2 + (y_1 - y_3)^2)^{3/2}} - \frac{y_0 - y_4}{((y_0 - y_4)^2 + (y_1 - y_5)^2)^{3/2}} \\
    - \frac{y_1 - y_3}{((y_0 - y_2)^2 + (y_1 - y_3)^2)^{3/2}} - \frac{y_1 - y_5}{((y_0 - y_4)^2 + (y_1 - y_5)^2)^{3/2}} \\
    - \frac{y_2 - y_0}{((y_0 - y_2)^2 + (y_1 - y_3)^2)^{3/2}} - \frac{y_2 - y_4}{((y_2 - y_4)^2 + (y_3 - y_5)^2)^{3/2}} \\
    - \frac{y_3 - y_1}{((y_0 - y_2)^2 + (y_1 - y_3)^2)^{3/2}} - \frac{y_3 - y_5}{((y_2 - y_4)^2 + (y_3 - y_5)^2)^{3/2}} \\ 
    - \frac{y_4 - y_0}{((y_0 - y_4)^2 + (y_1 - y_5)^2)^{3/2}} - \frac{y_4 - y_2}{((y_2 - y_4)^2 + (y_3 - y_5)^2)^{3/2}} \\
    - \frac{y_5 - y_1}{((y_0 - y_4)^2 + (y_1 - y_5)^2)^{3/2}} - \frac{y_5 - y_3}{((y_2 - y_4)^2 + (y_3 - y_5)^2)^{3/2}}
\end{pmatrix}.
\end{align*}

In [2]:
# CUDA support: use the GPU whenever PyTorch can see one, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [60]:
# the deep neural network
class DNN(torch.nn.Module):
    """
    Fully connected network with tanh activations between consecutive linear layers.

    `layers` lists the width of every layer, input first and output last. The final
    linear layer is not followed by an activation.
    """

    def __init__(self, layers):

        super().__init__()

        # number of linear maps between consecutive widths
        self.depth = len(layers) - 1
        self.activation = torch.nn.Tanh

        # hidden blocks: Linear followed by tanh, named layer_i / activation_i
        named_modules = []
        for idx, (n_in, n_out) in enumerate(zip(layers[:-2], layers[1:-1])):
            named_modules.append(('layer_%d' % idx, torch.nn.Linear(n_in, n_out)))
            named_modules.append(('activation_%d' % idx, self.activation()))

        # output block: plain linear map without activation
        named_modules.append(
            ('layer_%d' % (self.depth - 1), torch.nn.Linear(layers[-2], layers[-1]))
        )

        self.layers = torch.nn.Sequential(OrderedDict(named_modules))

    def forward(self, x):
        # x = (t, y0), concatenated along the feature dimension by the caller
        return self.layers(x)

In [61]:
# tcPINN: time-consistent physics-informed neural network
class TcPINN():
    """
    Time-consistent PINN for the planar three-body problem.

    The network represents the flow map M(t, y0) = y0 + t * N(t, y0) of the first-order
    ODE dy/dt = f(y) (see `_rhs`) and is trained with three losses:

    - PINN loss:       dM/dt(t, y0) = f(M(t, y0))          at X_pinn      = (t, y0)
    - semigroup loss:  M(s + t, y0) = M(t, M(s, y0))        at X_semigroup = (s, t, y0)
    - smoothness loss: dM/dt(t, y0) = dM/dt(0, M(t, y0))    at X_smooth    = (t, y0)

    Parameters
    ----------
    X_pinn : array, columns (t, y0)
    X_semigroup : array, columns (s, t, y0)
    X_smooth : array, columns (t, y0)
    layers : list of int, layer widths forwarded to DNN (input width = 1 + dim(y0))
    T : float, time horizon of one tcPINN step
    """

    def __init__(self, X_pinn, X_semigroup, X_smooth, layers, T):

        # neural network architecture
        self.layers = layers
        self.dnn = DNN(layers).to(device)

        # semigroup PINN step time
        self.T = torch.tensor(T, dtype=torch.float32, device=device)

        # training data, created directly as float32 leaf tensors on `device`
        # (torch.tensor(..., requires_grad=True).float().to(device) yields a non-leaf copy)
        self.t_pinn = self._to_tensor(X_pinn[:, :1])
        self.y_pinn = self._to_tensor(X_pinn[:, 1:])

        self.s_semigroup = self._to_tensor(X_semigroup[:, :1])
        self.t_semigroup = self._to_tensor(X_semigroup[:, 1:2])
        self.y_semigroup = self._to_tensor(X_semigroup[:, 2:])

        self.t_smooth = self._to_tensor(X_smooth[:, :1])
        self.y_smooth = self._to_tensor(X_smooth[:, 1:])

        # optimization
        self.optimizer = torch.optim.LBFGS(
            self.dnn.parameters(), lr=1.0, max_iter=50000, max_eval=50000,
            history_size=50, tolerance_grad=1e-5, tolerance_change=np.finfo(float).eps,
            line_search_fn="strong_wolfe"
        )

        self.iter = 0

    @staticmethod
    def _to_tensor(array):
        """Copy `array` into a float32 tensor on `device` that requires grad."""
        return torch.tensor(array, dtype=torch.float32, device=device, requires_grad=True)

    def net_y(self, t, y0):
        """Evaluate M(t, y0) = y0 + t * N(t, y0) for 2-D batches t (N, 1) and y0 (N, d)."""

        # The M(t, y0) = y0 + t N(t, y0) scheme seems to drastically increase the accuracy
        # and enforces M(0, y0) = y0 exactly. It works fine with automatic differentiation.
        y = y0 + t * self.dnn(torch.cat([t, y0], dim=1))

        return y

    def net_derivative(self, t, y0):
        """
        Pytorch automatic differentiation to compute the time derivative of the network.

        Returns a list with one tensor per output component i; entry i has one row per
        input (t, y0) of the batch and holds d/dt M_i(t, y0). The graph is kept
        (create_graph=True) so that losses built from it can be backpropagated.
        """
        y = self.net_y(t, y0)

        derivatives = []
        for i in range(y.shape[1]):
            # one-hot grad_outputs select output component i in the vector-Jacobian product
            grad_outputs = torch.zeros_like(y)
            grad_outputs[:, i] = 1.
            derivatives.append(
                torch.autograd.grad(
                    y, t,
                    grad_outputs=grad_outputs,
                    retain_graph=True,
                    create_graph=True
                )[0]
            )

        return derivatives

    @staticmethod
    def _rhs(y):
        """
        Right-hand side f(y) of the planar three-body ODE with G = m = 1.

        Columns of y are (r11, r12, r21, r22, r31, r32, v11, v12, v21, v22, v31, v32);
        the result has the same shape: (velocities, accelerations).
        """
        r1, r2, r3 = y[:, 0:2], y[:, 2:4], y[:, 4:6]

        def pull(ri, rj):
            # acceleration of body i caused by body j: -(ri - rj) / |ri - rj|^3
            diff = ri - rj
            return -diff / (diff ** 2).sum(dim=1, keepdim=True) ** (3 / 2)

        a1 = pull(r1, r2) + pull(r1, r3)
        a2 = pull(r2, r1) + pull(r2, r3)
        a3 = pull(r3, r1) + pull(r3, r2)

        return torch.cat([y[:, 6:12], a1, a2, a3], dim=1)

    def loss_function(self):
        """
        L-BFGS closure: zero the gradients, compute loss_pinn + loss_smooth + loss_semigroup,
        backpropagate it, print progress every 10 calls, pickle this instance every
        100 calls, and return the total loss.
        """
        self.optimizer.zero_grad()

        # PINN loss: mean squared residual of dM/dt - f(M), summed over the state components
        y_pred = self.net_y(self.t_pinn, self.y_pinn)
        deriv_pred = torch.cat(self.net_derivative(self.t_pinn, self.y_pinn), dim=1)
        loss_pinn = torch.mean((deriv_pred - self._rhs(y_pred)) ** 2, dim=0).sum()

        # The general semigroup loss for autonomous ODEs
        y_pred_tps = self.net_y(self.s_semigroup + self.t_semigroup, self.y_semigroup)
        y_pred_s = self.net_y(self.s_semigroup, self.y_semigroup)
        y_pred_restart = self.net_y(self.t_semigroup, y_pred_s)
        loss_semigroup = torch.mean((y_pred_tps - y_pred_restart) ** 2)

        # The general smoothness loss: the time derivative must match after restarting at M(t, y0)
        y_pred_smooth = self.net_y(self.t_smooth, self.y_smooth)
        deriv_pred_below = self.net_derivative(self.t_smooth, self.y_smooth)
        deriv_pred_above = self.net_derivative(
            torch.zeros_like(self.t_smooth, requires_grad=True), y_pred_smooth
        )
        loss_smooth = sum(
            torch.mean((below - above) ** 2)
            for below, above in zip(deriv_pred_below, deriv_pred_above)
        )

        loss = loss_pinn + loss_smooth + loss_semigroup

        loss.backward()
        self.iter += 1

        if self.iter % 10 == 0:
            print(
                f"Iter {self.iter}, Loss: {loss.item():.5f}, Loss_pinn: {loss_pinn.item():.5f} "
                f"Loss_smooth: {loss_smooth.item():.5f}, Loss_semigroup: {loss_semigroup.item():.5f}"
            )

        if self.iter % 100 == 0:
            # checkpoint this instance itself, not whatever the global name `model` refers to
            with open(f"./model_iter{self.iter}.pkl", "wb") as handle:
                pickle.dump(self, handle, protocol=pickle.HIGHEST_PROTOCOL)

        return loss

    def train(self):
        """Put the network in training mode and run L-BFGS with `loss_function` as closure."""
        self.dnn.train()
        self.optimizer.step(self.loss_function)

    def predict(self, t, y0):
        """Evaluate M(t, y0) for array-like t (N, 1) and y0 (N, d); returns a NumPy array."""
        t = torch.tensor(t, dtype=torch.float32, device=device)
        y0 = torch.tensor(y0, dtype=torch.float32, device=device)

        self.dnn.eval()
        # inference needs no derivatives, so do not build an autograd graph
        with torch.no_grad():
            y = self.net_y(t, y0)

        return y.cpu().numpy()

### Setup Training Data

In [62]:
# network widths: input (t, y0) has 1 + 12 features, ten hidden layers of width 64, 12 outputs
layers = [13] + [64] * 10 + [12]

# time horizon of one tcPINN step and half-widths of the sampling box for positions / velocities
T = 2
max_r = 2.
max_v = 0.25

# number of training samples for the PINN, semigroup and smoothness losses
N_pinn = N_semigroup = N_smooth = 10000


def sample_y(max_r, max_v, N):
    """
    Draw N states uniformly from a box.

    The six position coordinates are drawn first from [-max_r, max_r), then the six
    velocity coordinates from [-max_v, max_v). Returns an array of shape (N, 12).
    """
    positions = np.random.uniform(low=-max_r, high=max_r, size=(N, 6))
    velocities = np.random.uniform(low=-max_v, high=max_v, size=(N, 6))
    return np.concatenate((positions, velocities), axis=1)


# PINN loss points: t uniform in [0, T), initial values from the sampling box
t_pinn = np.random.uniform(0, T, size=(N_pinn, 1))
y_pinn = sample_y(max_r, max_v, N_pinn)
X_pinn = np.concatenate((t_pinn, y_pinn), axis=1)


# semigroup loss points: (s, t) uniform on the triangle s, t >= 0, s + t <= T
# (square-root trick for uniform sampling of a triangle, then scaled by T)
r1 = np.random.uniform(0, 1, N_semigroup)
r2 = np.random.uniform(0, 1, N_semigroup)
sqrt_r1 = np.sqrt(r1)
s_semigroup = T * (sqrt_r1 * (1 - r2))[:, np.newaxis]
t_semigroup = T * (r2 * sqrt_r1)[:, np.newaxis]
y_semigroup = sample_y(max_r, max_v, N_semigroup)
X_semigroup = np.concatenate((s_semigroup, t_semigroup, y_semigroup), axis=1)


# smoothness loss points: same layout as the PINN points
t_smooth = np.random.uniform(0, T, size=(N_smooth, 1))
y_smooth = sample_y(max_r, max_v, N_smooth)
X_smooth = np.concatenate((t_smooth, y_smooth), axis=1)


# persist the sampled training sets so that a run can be reproduced later
for pkl_path, array in (
    ("./X_pinn_naive.pkl", X_pinn),
    ("./X_semigroup_naive.pkl", X_semigroup),
    ("./X_smooth_naive.pkl", X_smooth),
):
    with open(pkl_path, "wb") as handle:
        pickle.dump(array, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [63]:
# build the tcPINN from the sampled training sets
model = TcPINN(
    X_pinn=X_pinn,
    X_semigroup=X_semigroup,
    X_smooth=X_smooth,
    layers=layers,
    T=T,
)

In [64]:
%%time
               
# Minimise the total tcPINN loss with L-BFGS (self.optimizer.step(self.loss_function));
# the run logged below was stopped manually with a KeyboardInterrupt.
model.train()

Iter 10, Loss: 9436.41406, Loss_pinn: 9436.41406 Loss_smooth: 0.00001, Loss_semigroup: 0.00000
Iter 20, Loss: 8380.43164, Loss_pinn: 8380.43164 Loss_smooth: 0.00001, Loss_semigroup: 0.00000
Iter 30, Loss: 3440.80420, Loss_pinn: 3440.80420 Loss_smooth: 0.00002, Loss_semigroup: 0.00000
Iter 40, Loss: 2640.09766, Loss_pinn: 2640.09766 Loss_smooth: 0.00006, Loss_semigroup: 0.00000
Iter 50, Loss: 1950.46606, Loss_pinn: 1950.46594 Loss_smooth: 0.00009, Loss_semigroup: 0.00000
Iter 60, Loss: 1453.66809, Loss_pinn: 1453.66797 Loss_smooth: 0.00012, Loss_semigroup: 0.00000
Iter 70, Loss: 1027.11914, Loss_pinn: 1027.11890 Loss_smooth: 0.00026, Loss_semigroup: 0.00000
Iter 80, Loss: 812.82050, Loss_pinn: 812.82007 Loss_smooth: 0.00042, Loss_semigroup: 0.00001
Iter 90, Loss: 672.88214, Loss_pinn: 672.88116 Loss_smooth: 0.00097, Loss_semigroup: 0.00001
Iter 100, Loss: 530.15588, Loss_pinn: 530.15082 Loss_smooth: 0.00499, Loss_semigroup: 0.00006
Iter 110, Loss: 413.55954, Loss_pinn: 413.54645 Loss_sm

Iter 910, Loss: 2.24990, Loss_pinn: 2.10371 Loss_smooth: 0.14201, Loss_semigroup: 0.00419
Iter 920, Loss: 2.23374, Loss_pinn: 2.08447 Loss_smooth: 0.14518, Loss_semigroup: 0.00409
Iter 930, Loss: 2.20799, Loss_pinn: 2.06762 Loss_smooth: 0.13630, Loss_semigroup: 0.00406
Iter 940, Loss: 2.18804, Loss_pinn: 2.04557 Loss_smooth: 0.13830, Loss_semigroup: 0.00418
Iter 950, Loss: 2.16486, Loss_pinn: 2.02391 Loss_smooth: 0.13690, Loss_semigroup: 0.00405
Iter 960, Loss: 2.14405, Loss_pinn: 2.00420 Loss_smooth: 0.13582, Loss_semigroup: 0.00403
Iter 970, Loss: 2.12817, Loss_pinn: 1.99320 Loss_smooth: 0.13102, Loss_semigroup: 0.00396
Iter 980, Loss: 2.11092, Loss_pinn: 1.97756 Loss_smooth: 0.12949, Loss_semigroup: 0.00387
Iter 990, Loss: 2.08783, Loss_pinn: 1.95575 Loss_smooth: 0.12813, Loss_semigroup: 0.00395
Iter 1000, Loss: 2.07039, Loss_pinn: 1.93335 Loss_smooth: 0.13315, Loss_semigroup: 0.00388
Iter 1010, Loss: 2.05360, Loss_pinn: 1.92139 Loss_smooth: 0.12848, Loss_semigroup: 0.00373
Iter 102

Iter 1820, Loss: 1.52310, Loss_pinn: 1.43440 Loss_smooth: 0.08569, Loss_semigroup: 0.00300
Iter 1830, Loss: 1.52029, Loss_pinn: 1.43119 Loss_smooth: 0.08612, Loss_semigroup: 0.00298
Iter 1840, Loss: 1.51702, Loss_pinn: 1.42858 Loss_smooth: 0.08546, Loss_semigroup: 0.00298
Iter 1850, Loss: 1.51301, Loss_pinn: 1.42568 Loss_smooth: 0.08439, Loss_semigroup: 0.00294
Iter 1860, Loss: 1.50828, Loss_pinn: 1.42120 Loss_smooth: 0.08417, Loss_semigroup: 0.00291
Iter 1870, Loss: 1.50522, Loss_pinn: 1.41834 Loss_smooth: 0.08401, Loss_semigroup: 0.00288
Iter 1880, Loss: 1.50144, Loss_pinn: 1.41630 Loss_smooth: 0.08224, Loss_semigroup: 0.00289
Iter 1890, Loss: 1.49823, Loss_pinn: 1.41277 Loss_smooth: 0.08257, Loss_semigroup: 0.00289
Iter 1900, Loss: 1.49567, Loss_pinn: 1.41071 Loss_smooth: 0.08206, Loss_semigroup: 0.00290
Iter 1910, Loss: 1.49312, Loss_pinn: 1.40796 Loss_smooth: 0.08228, Loss_semigroup: 0.00289
Iter 1920, Loss: 1.49317, Loss_pinn: 1.40830 Loss_smooth: 0.08196, Loss_semigroup: 0.00291

Iter 2730, Loss: 1.29262, Loss_pinn: 1.20918 Loss_smooth: 0.08034, Loss_semigroup: 0.00309
Iter 2740, Loss: 1.29105, Loss_pinn: 1.20744 Loss_smooth: 0.08051, Loss_semigroup: 0.00310
Iter 2750, Loss: 1.28971, Loss_pinn: 1.20561 Loss_smooth: 0.08100, Loss_semigroup: 0.00309
Iter 2760, Loss: 1.28817, Loss_pinn: 1.20399 Loss_smooth: 0.08107, Loss_semigroup: 0.00310
Iter 2770, Loss: 1.28691, Loss_pinn: 1.20246 Loss_smooth: 0.08134, Loss_semigroup: 0.00312
Iter 2780, Loss: 1.28567, Loss_pinn: 1.20147 Loss_smooth: 0.08108, Loss_semigroup: 0.00313
Iter 2790, Loss: 1.28392, Loss_pinn: 1.19972 Loss_smooth: 0.08106, Loss_semigroup: 0.00314
Iter 2800, Loss: 1.28271, Loss_pinn: 1.19851 Loss_smooth: 0.08105, Loss_semigroup: 0.00314
Iter 2810, Loss: 1.28058, Loss_pinn: 1.19684 Loss_smooth: 0.08055, Loss_semigroup: 0.00319
Iter 2820, Loss: 1.27865, Loss_pinn: 1.19514 Loss_smooth: 0.08032, Loss_semigroup: 0.00319
Iter 2830, Loss: 1.27728, Loss_pinn: 1.19352 Loss_smooth: 0.08057, Loss_semigroup: 0.00319

Iter 3640, Loss: 1.18907, Loss_pinn: 1.11156 Loss_smooth: 0.07442, Loss_semigroup: 0.00309
Iter 3650, Loss: 1.18829, Loss_pinn: 1.11099 Loss_smooth: 0.07422, Loss_semigroup: 0.00308
Iter 3660, Loss: 1.18743, Loss_pinn: 1.11007 Loss_smooth: 0.07430, Loss_semigroup: 0.00306
Iter 3670, Loss: 1.18604, Loss_pinn: 1.10922 Loss_smooth: 0.07375, Loss_semigroup: 0.00307
Iter 3680, Loss: 1.18521, Loss_pinn: 1.10850 Loss_smooth: 0.07361, Loss_semigroup: 0.00310
Iter 3690, Loss: 1.18458, Loss_pinn: 1.10722 Loss_smooth: 0.07424, Loss_semigroup: 0.00311
Iter 3700, Loss: 1.18355, Loss_pinn: 1.10666 Loss_smooth: 0.07378, Loss_semigroup: 0.00311
Iter 3710, Loss: 1.18270, Loss_pinn: 1.10568 Loss_smooth: 0.07390, Loss_semigroup: 0.00312
Iter 3720, Loss: 1.18169, Loss_pinn: 1.10513 Loss_smooth: 0.07344, Loss_semigroup: 0.00312
Iter 3730, Loss: 1.18104, Loss_pinn: 1.10430 Loss_smooth: 0.07360, Loss_semigroup: 0.00314


KeyboardInterrupt: 

In [65]:
# snapshot the (possibly partially) trained tcPINN, optimizer state included
with open("./model.pkl", "wb") as model_file:
    pickle.dump(model, model_file, protocol=pickle.HIGHEST_PROTOCOL)

In [71]:
# restore the pickled tcPINN (unpickling runs arbitrary code: only load files you created)
with open("./model.pkl", "rb") as model_file:
    model = pickle.Unpickler(model_file).load()

## Predict and Plot the Solution

In [72]:
def generate_figure(figsize, xlim, ylim):
    """
    Create a figure with a single axes, fixed axis limits and no top/right spines.

    Returns the tuple (fig, ax).
    """
    fig, ax = plt.subplots(figsize=figsize)
    for side in ('top', 'right'):
        ax.spines[side].set_visible(False)
    ax.set(xlim=xlim, ylim=ylim)
    return fig, ax


def plot_ode_solution(ax, y, index0, index1, *args, **kwargs):
    """
    Plot column `index1` of the trajectory `y` against column `index0` on `ax`
    with the '.-' format; extra arguments are forwarded to `ax.plot`. Returns `ax`.
    """
    xs = y[:, index0]
    ys = y[:, index1]
    ax.plot(xs, ys, '.-', *args, **kwargs)
    return ax

In [73]:
def predict_tc(model, y0, max_t_pred, delta_t):
    """
    Roll out a tcPINN beyond its training horizon by restarting it repeatedly.

    The network is only trained for times in [0, model.T]. Starting from y0, each segment
    predicts the states at delta_t, 2*delta_t, ..., n*delta_t, where n is the largest
    integer with n*delta_t <= model.T. The last state of a segment is the initial value
    of the next one. Segments are appended until max_t_pred is covered, so all returned
    states lie on the equidistant grid 0, delta_t, 2*delta_t, ...

    Parameters
    ----------
    model : object with attribute T (float or 0-d tensor) and method predict(t, y0)
    y0 : sequence of floats, initial state
    max_t_pred : float, total prediction time
    delta_t : float, time step with 0 < delta_t <= model.T

    Returns
    -------
    np.ndarray of shape (1 + n_segments * n, len(y0)); the first row is y0.

    Raises
    ------
    ValueError
        If delta_t is not positive or larger than model.T.
    """
    # float() also handles a (CUDA) torch tensor, which np.ceil / np.arange reject
    horizon = float(model.T)
    if delta_t <= 0:
        raise ValueError("delta_t must be positive")

    # small tolerance so that e.g. 0.3 / 0.1 = 2.9999999999999996 still counts as 3 steps
    n_steps = int(np.floor(horizon / delta_t + 1e-9))
    if n_steps < 1:
        raise ValueError("delta_t must not exceed the model time horizon model.T")

    # integer multiples of delta_t instead of np.arange with a float step, which can
    # emit an extra time point beyond the training horizon
    times = delta_t * np.arange(1, n_steps + 1)[:, np.newaxis]
    segment_length = n_steps * delta_t
    n_segments = int(np.ceil(max_t_pred / segment_length - 1e-9))

    current = np.asarray(y0, dtype=float)
    segments = [current[np.newaxis, :]]

    for _ in range(n_segments):
        segment = model.predict(times, np.tile(current, (n_steps, 1)))
        segments.append(segment)
        current = segment[-1]

    return np.vstack(segments)

In [74]:
# Initial value: the three bodies start at rest with their barycenter at the origin.
# The network is trained only for t in [0, T] (T = 2 in the training setup) and on initial
# values whose position coordinates lie in [-max_r, max_r] and velocity coordinates in
# [-max_v, max_v]. predict_tc restarts it repeatedly to reach max_t_pred, so any state
# that leaves that box is an extrapolation.
y0 = [1., 0., 0., 1., -1., -1., .0, .0, .0, .0, .0, .0]
max_t_pred = 10.
delta_t = 0.01

validation_tc = predict_tc(model, y0, max_t_pred, delta_t)

In [1]:
fig, ax = generate_figure(figsize=(8, 8), xlim=[-7, 7], ylim=[-7, 7])

# columns (0, 1), (2, 3) and (4, 5) of the trajectory are the positions of bodies 1, 2 and 3
bodies = [("Body 1", "#03468F"), ("Body 2", "#A51C30"), ("Body 3", "orange")]
for k, (label, color) in enumerate(bodies):
    ax = plot_ode_solution(ax, validation_tc, 2 * k, 2 * k + 1, markevery=[0], label=label, color=color)

ax.set(xlabel="$x$", ylabel="$y$")
# use the explicit figure/axes instead of pyplot's "current figure" state
ax.legend()
fig.savefig("3_body_problem.pdf", bbox_inches="tight")
plt.show()