In [None]:
import sys
import os

# Make the project root importable so that `import pinns` works without
# installing the package. This assumes the kernel's working directory is one
# level below the project root (e.g. `<root>/notebooks/`).
# NOTE(review): os.getcwd() is the kernel's working directory, which is usually
# but not always the notebook's own directory — confirm when running elsewhere.
current_dir = os.getcwd()

# Parent of the working directory; the former `join(..., '.')` was a no-op.
project_root = os.path.abspath(os.path.dirname(current_dir))
if project_root not in sys.path:
    sys.path.append(project_root)

import pinns

# For cleaner output: hide all UserWarnings for the rest of the session.
import warnings
warnings.filterwarnings("ignore", category=UserWarning)

We will solve the following problem: given $x_0$ and $v_0$, find $x(t): \mathbb{R} \to \mathbb{R}$ such that

$$\frac{\mathrm{d}^2 x}{\mathrm{d}t^2} + 2\zeta \omega_{0} \frac{\mathrm{d}x}{\mathrm{d}t} + \omega_{0}^{2}x = 0$$
$$x(0) = x_{0}, \frac{\mathrm{d}x}{\mathrm{d}t}(0) = v_{0}$$

The parameters $\zeta$ (damping ratio) and $\omega_{0}$ (natural frequency) are physical parameters that characterize the oscillation and its damping.

We consider the problem with $\zeta = 0.2$, $\omega_{0} = 2$, $x_0 = 5$ and $v_0 = 7$ on the time domain $t \in [0, 10]$.

In [None]:
import torch
import torch.nn as nn

import matplotlib.pyplot as plt

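We know the analytical solution, so we can measure the actual error of our model. In the underdamped case $0 \le \zeta < 1$ it reads

$$x(t) = e^{-\zeta\omega_0 t}\left(x_0\cos\omega_d t + \frac{v_0 + \zeta\omega_0 x_0}{\omega_d}\sin\omega_d t\right), \qquad \omega_d = \omega_0\sqrt{1-\zeta^2},$$

which for our parameters gives $\zeta\omega_0 = 0.4$, $\omega_d \approx 1.96$ and $(v_0 + \zeta\omega_0 x_0)/\omega_d \approx 4.59$.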

In [None]:
def analytical(t, zeta=0.2, omega=2.0, x0=5.0, v0=7.0):
    """Exact solution of the underdamped harmonic oscillator.

    Solves  x'' + 2*zeta*omega*x' + omega**2 * x = 0,  x(0) = x0,  x'(0) = v0
    in closed form:

        x(t) = exp(-zeta*omega*t) * (x0*cos(wd*t) + B*sin(wd*t)),
        wd   = omega * sqrt(1 - zeta**2),
        B    = (v0 + zeta*omega*x0) / wd.

    The coefficients are computed from the parameters instead of being
    hard-coded as rounded numbers, so x(0) and x'(0) match exactly.

    Args:
        t: tensor of time points (any shape).
        zeta: damping ratio; must satisfy 0 <= zeta < 1 (underdamped).
        omega: natural angular frequency; must be positive.
        x0: initial position x(0).
        v0: initial velocity x'(0).

    Returns:
        Tensor of the same shape as `t` with the solution values.

    Raises:
        ValueError: if the parameters are outside the underdamped regime
            handled by this formula.
    """
    if not 0 <= zeta < 1 or omega <= 0:
        raise ValueError(
            f"analytical() handles only 0 <= zeta < 1 and omega > 0, "
            f"got zeta={zeta}, omega={omega}"
        )
    decay = zeta * omega                       # exponential decay rate
    omega_d = omega * (1 - zeta**2) ** 0.5     # damped angular frequency
    sin_coef = (v0 + decay * x0) / omega_d     # chosen so that x'(0) = v0
    return torch.exp(-decay * t) * (
        sin_coef * torch.sin(omega_d * t) + x0 * torch.cos(omega_d * t)
    )

# Problem parameters, fixed in one place for clarity.
# NOTE(review): `analytical` must describe the same zeta/omega/x0/v0 as these
# names — keep them in sync if you change the problem.
T = 10
zeta, omega = 0.2, 2.0
x0, v0 = 5.0, 7.0

# Uniform grid of 128 points on [0, T]; reused for the reference plot and,
# later, as the test set.
t = torch.linspace(0, T, 128)
solution = analytical(t)

fig, ax = plt.subplots(figsize=(5, 3))
ax.plot(t, solution)
ax.set_xlabel('$t$')
ax.set_ylabel('$x(t)$')
ax.set_title('Analytical solution')
ax.grid()
plt.show()

Now we need to define all the basic building blocks that will be used to train our model.

First of all, we need to build all of our samplers. We can use predefined samplers if their signatures satisfy all of our needs. Since the initial values are just numbers rather than functions or large arrays, it is reasonable to use `ConstantSampler` for them. For the collocation points we choose `RandomRectangularSampler`, and for the test points we use `ConstantSampler` again.

If you want, for example, to load data from disk, to sample it from some function, or to simplify the signatures, you should define your own sampler. You can also always redefine the training logic and write your own Trainer if it is really necessary (it should not be hard).

In [None]:
from pinns.samplers import ConstantSampler, RandomRectangularSampler

# Constraints sampler must return a tuple of tensors (points, values),
# each of shape [num_pts, coords]. There is a single constraint point, t = 0,
# carrying two targets: x(0) = x0 and x'(0) = v0, so the values are [1, 2].
# The point needs `requires_grad` because the loss differentiates the
# prediction with respect to it (autograd).
constraints_sampler = ConstantSampler((
    torch.tensor([[0.]], requires_grad=True),
    torch.tensor([[x0, v0]])
))

# Collocation sampler must return just a tensor of shape [num_pts, coords];
# here 256 random points in t ∈ [0, T] (presumably redrawn by the trainer).
domain = {'t': [0, T]}
collocation_sampler = RandomRectangularSampler(domain, 256, return_dict=False)

# Test points sampler output must have the same structure as the constraints
# sampler: (points [N, 1], reference values [N, 1]).
test_points_sampler = ConstantSampler((t.view(-1, 1), solution.view(-1, 1)))

Now we want to define the loss function. For differentiation, we use an autograd-based `Derivative` object `d` from `pinns.derivatives`.

Remember that the training logic must be consistent with the sampler outputs and with the way the model makes predictions.

In [None]:
from pinns.derivatives import Derivative

# Autograd-based differentiation helper shared by both loss terms below.
d = Derivative(method='autograd')

def loss(
    cstr_pts, cstr_pred, cstr_vals,
    coll_pts, coll_pred,
    zeta = 0.2, omega = 2.0
    ):
    """Physics-informed loss for the damped oscillator.

    Args:
        cstr_pts: constraint points (t = 0); must require grad.
        cstr_pred: model prediction at `cstr_pts`.
        cstr_vals: target initial values (x0, v0).
        coll_pts: collocation points; must require grad.
        coll_pred: model prediction at `coll_pts`.
        zeta, omega: physical parameters of the ODE. The defaults duplicate
            the problem's values and are NOT read from the notebook globals.

    Returns:
        Tuple (initial-condition MSE, ODE-residual MSE).
    """
    # Initial conditions: compare [x(0), x'(0)] with the prescribed values.
    dx_at_start = d(cstr_pred, cstr_pts)
    predicted_state = torch.hstack([cstr_pred, dx_at_start])
    initial_term = torch.mean(torch.square(predicted_state - cstr_vals))

    # Residual of x'' + 2*zeta*omega*x' + omega^2*x = 0 on collocation points.
    velocity, acceleration = d(coll_pred, coll_pts, orders=[1, 2])
    damping = 2 * zeta * omega
    stiffness = omega**2
    residual = acceleration + damping * velocity + stiffness * coll_pred
    ode_term = torch.mean(torch.square(residual))

    return initial_term, ode_term

Now we can define a neural network and train it using the default training logic.

In [None]:
from functools import partial

from pinns import Trainer
from pinns.models import FF
from pinns.optimizers import Adam

# Seed torch's global RNG so weight initialisation (and, presumably, the
# random collocation points) are reproducible on Restart & Run All.
torch.manual_seed(0)

# Fully-connected net; [1] + [64] + [1] presumably lists layer widths:
# input t -> 64 tanh units -> output x.
pinn = FF([1] + [64] + [1], activ=nn.Tanh(), biases=True)
print(f'Model has {pinn.count_parameters()} trainable parameters.')

adam = Adam(pinn, lr = 1e-2)

trainer = Trainer(
    # Bind the physical parameters explicitly so the ODE residual uses the
    # zeta/omega fixed with the problem, not the loss function's own defaults.
    partial(loss, zeta=zeta, omega=omega),
    pinn,
    constraints_sampler,
    collocation_sampler,
    loss_coefs=[0.8, 0.2],    # Coefficients are very important.
    test_points_sampler=test_points_sampler
)

trainer.train(
    num_iters=1000,
    optimizers=[(0, adam)],
    validate_every=1
    )

In [None]:
from pinns.errors import l2

# Score the trained model with the `l2` metric from `pinns.errors`
# (presumably against the test points sampler given to the Trainer).
final_l2_error = trainer.evaluate(l2)
print(f'L2 error of model is {final_l2_error:.5f}')

Whether this is a good value depends on the particular case.

In [None]:
fig, axs = plt.subplots(1, 2, figsize=(10, 3))

# Left: training history. The x-range for the error curve assumes one
# validation per iteration (validate_every=1 in the training cell);
# NOTE(review): adjust the step if that setting changes.
axs[0].plot(trainer.loss_history, label='Loss')
axs[0].plot(range(0, trainer.iter + 1, 1), trainer.error_history, label='L2')
axs[0].set_xlabel('iteration')
axs[0].set_title('Training history')
axs[0].grid()
axs[0].set_yscale('log')
axs[0].legend()

# Right: PINN prediction vs. the analytical solution on the test grid.
preds = pinn.predict(t.reshape(-1, 1))
axs[1].plot(t, solution, label='Solution')
axs[1].plot(t, preds.detach(), label='Predicts', linestyle=':')
axs[1].set_xlabel('$t$')
axs[1].set_ylabel('$x(t)$')
axs[1].set_title('PINN vs. analytical solution')
axs[1].grid()
axs[1].legend()

plt.show()