In [12]:
import sys
import os

current_dir = os.getcwd()

project_root = os.path.abspath(os.path.join(os.path.dirname(current_dir), '..'))
if project_root not in sys.path:
    sys.path.append(project_root)

import pinns

# For cleaner output.
import warnings
warnings.filterwarnings("ignore", category=UserWarning)

In [13]:
import torch
import torch.nn as nn

import numpy as np 

import matplotlib.pyplot as plt

import time

from pinns import Trainer
from pinns.models import FF, KAN
from pinns.errors import l2
from pinns.optimizers import Adam
from pinns.derivatives import Derivative
from pinns.samplers import RandomSampler, ConstantSampler, DataSampler

from prettytable import PrettyTable

In [23]:
gridits = [
    [(2, 500)],
    [(5, 500)],
    [(10, 500)],
    [(20, 500)],
    [(2, 150), (5, 350)],
    [(2, 150), (20, 350)],
    [(5, 150), (10, 350)],
    [(5, 150), (20, 350)],
    [(10, 150), (20, 350)],
    [(2, 250), (5, 250)],
    [(2, 250), (10, 250)],
    [(2, 250), (20, 250)],
    [(5, 250), (20, 250)],
    [(10, 250), (20, 250)],
    [(2, 350), (5, 350)],
    [(5, 350), (20, 150)],
    [(2, 450), (10, 50)],
]

len(gridits)

17

---
## Damped Harmonic Oscillator

In [15]:
from scipy.integrate import solve_ivp

def dho(t, y, zeta, omega):
    x, v = y
    dxdt = v
    dvdt = -2 * zeta * omega * v - omega**2 * x
    return [dxdt, dvdt]

T = 10              # We need to fix some computational domain.
p = (0.2, 2.0)      # Parameters that yields interesting function.
x0, v0 = 5.0, 7.0   # Initial values may be arbitrary.

t = torch.linspace(0, T, 128)
solution = solve_ivp(dho, (0, T), (x0, v0), t_eval=t, args=p).y[0]

In [16]:
from pinns.samplers import ConstantSampler, RandomSampler

# Constraints (training data) sampler must return tuple (points, values).
dho_constraints_sampler = ConstantSampler((
    torch.tensor([[0.]], requires_grad=True),
    torch.tensor([x0, v0])
))

# Collocation sampler must return just tensor of shape [num_pts, coords].
dho_domain = {'t': [0, T]}
dho_collocation_sampler = RandomSampler(dho_domain, 256, return_dict=False)

# Test points sampler must return tuple (points, values) of shape [num_pts, coords].
dho_test_points_sampler = ConstantSampler(
    (t.reshape(-1, 1), solution.reshape(-1, 1))
    )

In [17]:
d = Derivative(method = 'autograd')

def dho_loss(
    cstr_pts, cstr_pred, cstr_vals,
    coll_pts, coll_pred,
    zeta = 0.2, omega = 2.0
    ):
    
    def init_loss(x0, t0):
        v0 = d(x0, t0)
        return torch.mean(torch.square(torch.hstack([x0, v0]) - cstr_vals))

    def ode_loss(x, t):
        v, a = d(x, t, orders = [1, 2])
        return torch.mean(torch.square(a + 2 * zeta * omega * v + omega**2 * x))
    
    losses = (
        init_loss(cstr_pred, cstr_pts),
        ode_loss(coll_pred, coll_pts)
    )
    
    return losses

In [18]:
def train_dho(gridit, lr = 1e-2, loss_coefs = [0.8, 0.2], plot = True):
    
    pinn = KAN([1, 5, 1], grid = 5)
    
    trainer = Trainer(
        dho_loss,
        pinn,
        dho_constraints_sampler,
        dho_collocation_sampler,
        loss_coefs=loss_coefs,    # Coefficients are very important.
        test_points_sampler=dho_test_points_sampler
    )
    
    for i, (grid, it) in enumerate(gridit):
    
        pts, _ = dho_test_points_sampler()

        pinn = KAN([1, 5, 1], grid = grid).initialize_from_another_model(trainer.model, pts)
        trainer.model = pinn
            
        adam = Adam(pinn, lr = lr)
        
        # trainer.iter += 1
        trainer.train(
            num_iters = it,
            optimizers=[(0, adam)],
            validate_every=1,
            show_progress=True
        )
    
    error = trainer.evaluate(l2)
    # print(f'Error is {error:.5f}')
    
    if plot:
        fig, axs = plt.subplots(1, 2, figsize=(10, 3))

        axs[0].plot(trainer.loss_history, label='Loss')
        axs[0].plot(trainer.error_history, label='L2')
        axs[0].grid()
        axs[0].set_yscale('log')
        axs[0].legend()

        preds = pinn.predict(t.reshape(-1, 1))
        axs[1].plot(t, solution, label='Solution')
        axs[1].plot(t, preds.detach(), label='Predicts', linestyle=':')
        axs[1].grid()
        axs[1].legend()

        plt.show()
    
    return error, trainer.loss_history

---
## Diffusion 1D

In [19]:
path = './examples/diff1d/'

def get_data(path):
    init = torch.tensor(np.load(path + 'init_data.npy'))
    left = torch.tensor(np.load(path + 'left_data.npy'))
    right = torch.tensor(np.load(path + 'right_data.npy'))
    
    return ([init[:,  :2], left[:,  :2], right[:,  :2]], 
            [init[:, [2]], left[:, [2]], right[:, [2]]])

diff_constraints_sampler = ConstantSampler(get_data(path))

diff_domain = {'t': [0, 0.5], 'x': [0, 1]}
diff_collocation_sampler = RandomSampler(diff_domain, 2048, return_dict=True)

diff_test_points_sampler = DataSampler(path + 'solution.npy', 1024, 2)

In [20]:
d = Derivative(method='autograd')

def diff_loss(
    cstr_pts, cstr_pred, cstr_vals,
    coll_pts, coll_pred,
    D = 0.5
    ):
    
    # We do not need left and right because we are solving 
    # Dirichlet problem and we just compare predictions and 
    # solution. If we solve Cauchy or Robin problem, we 
    # need to calculate derivatives at boundary points.
    
    init_pts, left_pts, right_pts = cstr_pts
    init_pred, left_pred, right_pred = cstr_pred
    init_vals, left_vals, right_vals = cstr_vals
    
    # If we use RandomSampler, we are unable to split values
    # like that. But it is just for clarity.
    
    t, x = coll_pts['t'], coll_pts['x']
    
    def initial_loss():
        return torch.mean(torch.square(init_pred - init_vals))
    
    def left_loss():
        return torch.mean(torch.square(left_pred - left_vals))
    
    def right_loss():
        return torch.mean(torch.square(right_pred - right_vals))
    
    def pde_loss(u, t, x):
        ut  = d(u,  t)
        uxx = d(u,  x, orders = 2)
        
        return torch.mean(torch.square(ut - D * uxx))
    
    return (
        initial_loss(), 
        left_loss(), 
        right_loss(), 
        pde_loss(coll_pred, t, x)
    )

In [21]:
Nt, Nx = 500, 750
pts, values = diff_test_points_sampler(full=True)

pts = [
    pts[:,0].reshape(Nx, Nt),
    pts[:,1].reshape(Nx, Nt)
]
values = values.reshape(Nx, Nt)

cstr_pts, cstr_vals = diff_constraints_sampler()
stacked_pts = torch.cat([torch.hstack([t[:, [1]], t[:, [0]]]) for t in cstr_pts])
stacked_vals = torch.cat(cstr_vals)
constraints = torch.hstack([stacked_pts, stacked_vals.reshape(-1, 1)]).T

def train_diff(gridit, lr = 1e-2, loss_coefs = [0.75, 0.75, 0.75, 0.25], plot = True):

    pinn = KAN([2, 10, 1], grid = 5)
    
    trainer = Trainer(
        diff_loss,
        pinn,
        diff_constraints_sampler,
        diff_collocation_sampler,
        loss_coefs=loss_coefs,    # Coefficients are very important.
        test_points_sampler=diff_test_points_sampler
    )
    
    for i, (grid, it) in enumerate(gridit):
    
        points, _ = diff_test_points_sampler()

        pinn = KAN([2, 10, 1], grid = grid).initialize_from_another_model(trainer.model, points)
        trainer.model = pinn
            
        adam = Adam(pinn, lr = lr)
        
        # trainer.iter += 1
        trainer.train(
            num_iters = it,
            optimizers=[(0, adam)]
        )
        
    error = trainer.evaluate(l2, full = True)
    # print(f'Error is {error:.5f}')
    
    if plot:
        fig = plt.figure(figsize=(12, 5))

        ax = fig.add_subplot(121)
        ax.plot(trainer.loss_history, label='Loss')
        # ax.plot(range(0, trainer.iter + 1, 1), trainer.error_history, label='Error')
        ax.grid()
        ax.set_yscale('log')
        ax.legend()

        preds = pinn.predict(diff_test_points_sampler(full=True)[0]).detach().reshape(Nx, Nt)

        ax = fig.add_subplot(122, projection='3d')
        ax.plot_surface(pts[1], pts[0], preds, cmap='viridis')
        # ax.plot_surface(pts[1], pts[0], values, cmap='viridis')

        ax.scatter3D(*constraints, color='r', s=10)
        # ax.view_init(80, -120)

        plt.tight_layout()
        plt.show()
    
    return error, trainer.loss_history

In [24]:
table = PrettyTable(['Name', 'DHO Error', 'DIFF Error'])

for gridit in gridits:
    dho_error, _ = train_dho(gridit, lr = 1e-2, plot = False)
    diff_error, _ = train_diff(gridit, lr = 1e-2, plot = False)
    table.add_row([str(gridit), np.round(dho_error, decimals=2), np.round(diff_error, decimals=2)])

# table.sortby = 'Error'
print(table)

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/150 [00:00<?, ?it/s]

  0%|          | 0/350 [00:00<?, ?it/s]

  0%|          | 0/150 [00:00<?, ?it/s]

  0%|          | 0/350 [00:00<?, ?it/s]

  0%|          | 0/150 [00:00<?, ?it/s]

  0%|          | 0/350 [00:00<?, ?it/s]

  0%|          | 0/150 [00:00<?, ?it/s]

  0%|          | 0/350 [00:00<?, ?it/s]

  0%|          | 0/150 [00:00<?, ?it/s]

  0%|          | 0/350 [00:00<?, ?it/s]

  0%|          | 0/150 [00:00<?, ?it/s]

  0%|          | 0/350 [00:00<?, ?it/s]

  0%|          | 0/150 [00:00<?, ?it/s]

  0%|          | 0/350 [00:00<?, ?it/s]

  0%|          | 0/150 [00:00<?, ?it/s]

  0%|          | 0/350 [00:00<?, ?it/s]

  0%|          | 0/150 [00:00<?, ?it/s]

  0%|          | 0/350 [00:00<?, ?it/s]

  0%|          | 0/150 [00:00<?, ?it/s]

  0%|          | 0/350 [00:00<?, ?it/s]

  0%|          | 0/250 [00:00<?, ?it/s]

  0%|          | 0/250 [00:00<?, ?it/s]

  0%|          | 0/250 [00:00<?, ?it/s]

  0%|          | 0/250 [00:00<?, ?it/s]

  0%|          | 0/250 [00:00<?, ?it/s]

  0%|          | 0/250 [00:00<?, ?it/s]

  0%|          | 0/250 [00:00<?, ?it/s]

  0%|          | 0/250 [00:00<?, ?it/s]

  0%|          | 0/250 [00:00<?, ?it/s]

  0%|          | 0/250 [00:00<?, ?it/s]

  0%|          | 0/250 [00:00<?, ?it/s]

  0%|          | 0/250 [00:00<?, ?it/s]

  0%|          | 0/250 [00:00<?, ?it/s]

  0%|          | 0/250 [00:00<?, ?it/s]

  0%|          | 0/250 [00:00<?, ?it/s]

  0%|          | 0/250 [00:00<?, ?it/s]

  0%|          | 0/250 [00:00<?, ?it/s]

  0%|          | 0/250 [00:00<?, ?it/s]

  0%|          | 0/250 [00:00<?, ?it/s]

  0%|          | 0/250 [00:00<?, ?it/s]

  0%|          | 0/350 [00:00<?, ?it/s]

  0%|          | 0/350 [00:00<?, ?it/s]

  0%|          | 0/350 [00:00<?, ?it/s]

  0%|          | 0/350 [00:00<?, ?it/s]

  0%|          | 0/350 [00:00<?, ?it/s]

  0%|          | 0/150 [00:00<?, ?it/s]

  0%|          | 0/350 [00:00<?, ?it/s]

  0%|          | 0/150 [00:00<?, ?it/s]

  0%|          | 0/450 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/450 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

+------------------------+-----------+------------+
|          Name          | DHO Error | DIFF Error |
+------------------------+-----------+------------+
|       [(2, 500)]       |    3.31   |   37.49    |
|       [(5, 500)]       |    2.98   |   10.08    |
|      [(10, 500)]       |    2.29   |    5.92    |
|      [(20, 500)]       |    1.99   |    5.63    |
|  [(2, 150), (5, 350)]  |    3.38   |   14.15    |
| [(2, 150), (20, 350)]  |    5.2    |    7.33    |
| [(5, 150), (10, 350)]  |    3.3    |    5.2     |
| [(5, 150), (20, 350)]  |    1.62   |    5.03    |
| [(10, 150), (20, 350)] |    1.65   |   10.34    |
|  [(2, 250), (5, 250)]  |    3.51   |   22.84    |
| [(2, 250), (10, 250)]  |    3.6    |   11.07    |
| [(2, 250), (20, 250)]  |    3.17   |   11.16    |
| [(5, 250), (20, 250)]  |    1.64   |    4.84    |
| [(10, 250), (20, 250)] |    1.6    |    6.33    |
|  [(2, 350), (5, 350)]  |    3.41   |   12.56    |
| [(5, 350), (20, 150)]  |    5.8    |   10.77    |
|  [(2, 450)

In [26]:
table.sortby = 'DHO Error'
print(table)

+------------------------+-----------+------------+
|          Name          | DHO Error | DIFF Error |
+------------------------+-----------+------------+
| [(10, 250), (20, 250)] |    1.6    |    6.33    |
| [(5, 150), (20, 350)]  |    1.62   |    5.03    |
| [(5, 250), (20, 250)]  |    1.64   |    4.84    |
| [(10, 150), (20, 350)] |    1.65   |   10.34    |
|      [(20, 500)]       |    1.99   |    5.63    |
|      [(10, 500)]       |    2.29   |    5.92    |
|       [(5, 500)]       |    2.98   |   10.08    |
| [(2, 250), (20, 250)]  |    3.17   |   11.16    |
| [(5, 150), (10, 350)]  |    3.3    |    5.2     |
|       [(2, 500)]       |    3.31   |   37.49    |
|  [(2, 150), (5, 350)]  |    3.38   |   14.15    |
|  [(2, 350), (5, 350)]  |    3.41   |   12.56    |
|  [(2, 250), (5, 250)]  |    3.51   |   22.84    |
| [(2, 250), (10, 250)]  |    3.6    |   11.07    |
| [(2, 150), (20, 350)]  |    5.2    |    7.33    |
| [(5, 350), (20, 150)]  |    5.8    |   10.77    |
|  [(2, 450)

In [27]:
table.sortby = 'DIFF Error'
print(table)

+------------------------+-----------+------------+
|          Name          | DHO Error | DIFF Error |
+------------------------+-----------+------------+
| [(5, 250), (20, 250)]  |    1.64   |    4.84    |
| [(5, 150), (20, 350)]  |    1.62   |    5.03    |
| [(5, 150), (10, 350)]  |    3.3    |    5.2     |
|      [(20, 500)]       |    1.99   |    5.63    |
|      [(10, 500)]       |    2.29   |    5.92    |
| [(10, 250), (20, 250)] |    1.6    |    6.33    |
| [(2, 150), (20, 350)]  |    5.2    |    7.33    |
|       [(5, 500)]       |    2.98   |   10.08    |
| [(10, 150), (20, 350)] |    1.65   |   10.34    |
| [(5, 350), (20, 150)]  |    5.8    |   10.77    |
| [(2, 250), (10, 250)]  |    3.6    |   11.07    |
| [(2, 250), (20, 250)]  |    3.17   |   11.16    |
|  [(2, 350), (5, 350)]  |    3.41   |   12.56    |
|  [(2, 150), (5, 350)]  |    3.38   |   14.15    |
|  [(2, 250), (5, 250)]  |    3.51   |   22.84    |
|       [(2, 500)]       |    3.31   |   37.49    |
|  [(2, 450)

In [25]:
with open('test.csv', 'w', newline='') as f_output:
    f_output.write(table.get_csv_string())