In [None]:
import sys
import os

current_dir = os.getcwd()

# Make the parent of the working directory (normally the notebook's folder)
# importable, so that the local `pinns` package can be found.
project_root = os.path.normpath(os.path.join(current_dir, os.pardir))
if project_root not in sys.path:
    sys.path.append(project_root)

import pinns

# For cleaner output.
import warnings
warnings.filterwarnings("ignore", category=UserWarning)

We solve the wave-equation problem from https://arxiv.org/abs/2007.14527 (section 7.3), which has a fabricated (manufactured) exact solution:

Find $u(t, x)$ such that

$$\frac{\partial^2 u}{\partial t^2} - 4\frac{\partial^2 u}{\partial x^2} = 0 \quad \text{on} \quad [0, 1] \times [0, 1]$$
$$u(t, 0) = u(t, 1) = 0$$
$$u(0, x) = \sin(\pi x) + \frac{1}{2}\sin(4\pi x)$$
$$\frac{\partial u}{\partial t}(0, x) = 0$$

The exact solution is:

$$u(t, x) = \sin(\pi x)\cos(2 \pi t) + \frac{1}{2}\sin(4 \pi x)\cos(8 \pi t)$$

In [None]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Evaluation grid over the unit square: 100 points in time and in space.
t = torch.linspace(0, 1, 100)
x = torch.linspace(0, 1, 100)
# `indexing='ij'` is torch's current default, passed explicitly because calling
# meshgrid without it emits a "will be required" UserWarning.
# Layout: X[i, j] = x[i], T[i, j] = t[j].
X, T = torch.meshgrid(x, t, indexing='ij')

PI = torch.pi


def cos(x):
    """Element-wise cosine; short alias for ``torch.cos``."""
    return torch.cos(x)


def sin(x):
    """Element-wise sine; short alias for ``torch.sin``."""
    return torch.sin(x)


def u(t, x):
    """Manufactured exact solution of u_tt = 4 u_xx.

    u(t, x) = sin(pi x) cos(2 pi t) + 1/2 sin(4 pi x) cos(8 pi t)

    Evaluated element-wise on tensors ``t`` and ``x``.
    """
    low_mode = sin(PI * x) * cos(2 * PI * t)
    high_mode = 0.5 * sin(4 * PI * x) * cos(8 * PI * t)
    return low_mode + high_mode

# Exact solution on the evaluation grid; with the 'ij' meshgrid layout,
# solution[i, j] = u(t[j], x[i]).
solution = u(T, X)

fig = plt.figure(figsize=(7, 7))
ax = fig.add_subplot(111, projection='3d')
ax.plot_surface(X, T, solution, cmap='viridis')
ax.set_xlabel('x')
ax.set_ylabel('t')
ax.set_zlabel('u')
ax.set_title('Exact solution $u(t, x)$')
plt.show()

In [None]:
from pinns.samplers import RandomRectangularSampler, ConstantSampler

# Constraints should be just (0, x), (t, 0), (t, 1)
def get_constraints():
    """Build constraint points and target values for the wave problem.

    Points are ``(t, x)`` rows on the initial line ``t = 0`` and on the two
    spatial boundaries ``x = 0`` and ``x = 1``, sampled from the global
    ``x`` and ``t`` grids.

    Returns:
        ``((init_pts, left_pts, right_pts), (init_values, left_values, right_values))``.
        ``init_values`` has two columns: the target ``u(0, x)`` and the target
        ``du/dt(0, x)`` (zero). ``left_values``/``right_values`` hold a single
        ``u`` column.
    """
    n_x = x.numel()
    n_t = t.numel()

    # The time column requires grad so the loss can differentiate the network
    # output with respect to the initial points (for the u_t(0, x) = 0 term).
    init_pts = torch.hstack(
        [torch.zeros((n_x, 1), requires_grad=True), x.reshape(-1, 1)]
    )
    left_pts = torch.hstack([t.reshape(-1, 1), torch.zeros((n_t, 1))])
    right_pts = torch.hstack([t.reshape(-1, 1), torch.ones((n_t, 1))])

    # Targets are fixed data. Evaluate the exact solution on detached
    # coordinates so the target tensors carry no autograd history; otherwise
    # they would be part of every loss graph built from them.
    init_t, init_x = init_pts.detach().unbind(dim=1)
    init_values = torch.stack([u(init_t, init_x), torch.zeros(n_x)]).T
    left_values = u(left_pts[:, 0], left_pts[:, 1]).reshape(-1, 1)
    right_values = u(right_pts[:, 0], right_pts[:, 1]).reshape(-1, 1)

    return ((init_pts, left_pts, right_pts),
            (init_values, left_values, right_values))
    
pts, vals = get_constraints()
constraints_sampler = ConstantSampler((pts, vals))

domain = {'t': [0, 1], 'x': [0, 1]}
collocation_sampler = RandomRectangularSampler(domain, 1024)

# Test points on the full evaluation grid, in the same (t, x) column order as
# the constraint points and `domain`. `T`, `X` and `solution` share one grid
# layout, so flattening all three identically keeps each point aligned with
# its exact value. This does not rely on the t and x grids being identical.
test_sampler = ConstantSampler((
    torch.hstack([T.reshape(-1, 1), X.reshape(-1, 1)]),
    solution.reshape(-1, 1)
))

In [None]:
from pinns.derivatives import Derivative

d = Derivative(method='autograd')


def loss(
    cstr_pts, cstr_pred, cstr_vals,
    coll_pts, coll_pred
    ):
    """Return the four unweighted PINN loss terms for the wave equation.

    Args:
        cstr_pts: ``(init_pts, left_pts, right_pts)`` constraint points.
        cstr_pred: Network predictions at those points, in the same order.
        cstr_vals: Target values, in the same order; the initial targets hold
            two columns (u and u_t).
        coll_pts: Mapping with ``'t'`` and ``'x'`` collocation coordinates.
        coll_pred: Network output at the collocation points (presumably
            computed from ``coll_pts`` so autograd can differentiate it).

    Returns:
        Tuple ``(initial, left, right, pde)`` of mean-squared errors, where
        the PDE residual is ``u_tt - 4 u_xx``.
    """
    init_pts, _left_pts, _right_pts = cstr_pts
    init_pred, left_pred, right_pred = cstr_pred
    init_vals, left_vals, right_vals = cstr_vals

    coll_t = coll_pts['t']
    coll_x = coll_pts['x']

    def mse(residual):
        return torch.mean(torch.square(residual))

    # Initial condition: compare [u, u_t] against the two target columns;
    # column 0 of the gradient is the derivative with respect to t.
    init_ut = d(init_pred, init_pts)[:, [0]]
    initial_term = mse(torch.hstack([init_pred, init_ut]) - init_vals)

    left_term = mse(left_pred - left_vals)
    right_term = mse(right_pred - right_vals)

    # Interior residual of u_tt = 4 u_xx.
    u_tt = d(coll_pred, coll_t, orders = 2)
    u_xx = d(coll_pred, coll_x, orders = 2)
    pde_term = mse(u_tt - 4 * u_xx)

    return (initial_term, left_term, right_term, pde_term)

In [None]:
from pinns import Trainer
from pinns.models import FF
from pinns.activations import Sin
from pinns.optimizers import Adam
from pinns.errors import rmse

# Seed torch's global RNG so that network initialisation (and, presumably, the
# random collocation points drawn during training) is reproducible.
SEED = 0
torch.manual_seed(SEED)

# Fully-connected network; the list is presumably the layer widths:
# 2 inputs (t, x), two hidden layers of 32 units, 1 output (u).
pinn = FF([2] + [32, 32] + [1], activ=nn.Tanh())

adam = Adam(pinn, lr = 1e-2)

# `loss_coefs` presumably weight the four terms returned by `loss`, in order:
# initial condition, left boundary, right boundary, PDE residual.
trainer = Trainer(
    loss,
    pinn,
    constraints_sampler,
    collocation_sampler,
    loss_coefs = [0.75, 0.75, 0.75, 0.25],
    test_points_sampler = test_sampler
)

num_iters = 15000
# Interval (in iterations) between saved prediction snapshots used for the animation.
save_every = 25

# Snapshot directory. Create it up front so `np.save` does not fail on a
# fresh checkout where `.temp` does not exist yet.
os.makedirs('./.temp', exist_ok=True)


def make_plot():
    """Save the network's test-grid predictions to ``./.temp/<iter>.npy``.

    Used as a training callback; despite its name it does not plot. It writes
    a snapshot at iteration 0, at every ``save_every``-th iteration and at
    iteration ``num_iters``.
    """
    if trainer.iter == 0 or trainer.iter % save_every == 0 or trainer.iter == num_iters:
        preds = pinn.predict(test_sampler()[0]).detach()
        # `.cpu()` so `.numpy()` also works if the model runs on a GPU.
        np.save(f'./.temp/{trainer.iter}.npy', preds.cpu().numpy())

# Train with Adam, validating every iteration and snapshotting predictions via
# `make_plot` at the start, at each epoch end and at the end of training.
train_config = dict(
    num_iters=num_iters,
    optimizers=[(0, adam)],
    validate_every=1,
    error_metric=rmse,
    at_training_start_callbacks=[make_plot],
    at_epoch_end_callbacks=[make_plot],
    at_training_end_callbacks=[make_plot],
)
trainer.train(**train_config)

In [None]:
test_inputs, test_vals = test_sampler()
grid_shape = X.shape

# The two test-point columns reshaped back onto the evaluation grid.
# The animation cell below reuses `test_pts` in this exact layout.
test_pts = [
    test_inputs[:, 0].reshape(grid_shape),
    test_inputs[:, 1].reshape(grid_shape)
]

fig = plt.figure(figsize=(12, 5))

# Left panel: training loss and test error on a log scale.
ax = fig.add_subplot(121)
ax.plot(trainer.loss_history, label='Loss')
# Derive the x-axis from the error history itself so the lengths always match
# (presumably 0..trainer.iter when validating every iteration).
ax.plot(range(len(trainer.error_history)), trainer.error_history, label='Error')
ax.grid()
ax.set_yscale('log')
ax.set_xlabel('iteration')
ax.set_title('Training history')
ax.legend()

# Right panel: PINN prediction on the test grid (same axes as the exact-solution plot).
preds = pinn.predict(test_inputs).detach().cpu().reshape(grid_shape)
ax = fig.add_subplot(122, projection='3d')
ax.plot_surface(test_pts[1], test_pts[0], preds, cmap='viridis')
ax.set_xlabel('x')
ax.set_ylabel('t')
ax.set_zlabel('u')
ax.set_title('PINN prediction')
plt.show()

In [None]:
from PIL import Image
import imageio
from joblib import Parallel, delayed

from tqdm.notebook import tqdm_notebook as tqdm

def save_animation(files, path, duration=5, fps=60, loop=0, type='mp4', processors=2):
    """Render saved prediction snapshots as frames and assemble an animation.

    Each ``.npy`` file in ``files`` is reshaped onto the global ``test_pts``
    grid and drawn as a 3D surface. Frames are written in parallel to
    ``./.temp/frame_<i>.png`` and then combined into ``path``.

    Args:
        files: Paths of ``.npy`` prediction snapshots, in frame order.
        path: Output file; its parent directory is created if missing.
        duration: GIF only. Passed to Pillow as the per-frame display time in
            milliseconds, so the default of 5 is very fast (many viewers
            clamp such small values).
        fps: Frame rate for MP4 output. It is also forwarded to Pillow for
            GIFs, where it is presumably ignored (NOTE(review): confirm).
        loop: GIF loop count (0 = loop forever).
        type: ``'mp4'`` or ``'gif'``.
        processors: Number of joblib workers used to render frames.

    Raises:
        ValueError: If ``type`` is unsupported or ``files`` is empty.
    """
    if type not in ('mp4', 'gif'):
        raise ValueError(f"Unsupported animation type {type!r}; expected 'mp4' or 'gif'.")
    if not files:
        raise ValueError("No snapshot files given; nothing to animate.")

    os.makedirs('./.temp', exist_ok=True)
    out_dir = os.path.dirname(path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    def plot(i):
        # A fresh figure per frame: workers never share figure state, and
        # closing it afterwards keeps memory bounded across many frames.
        fig = plt.figure(figsize=(6, 6))
        try:
            predictions = np.load(files[i]).reshape(test_pts[0].shape)
            ax = fig.add_subplot(111, projection='3d')
            # Fixed plot limits so the view is stable across frames.
            ax.set_xlim(0, 1)
            ax.set_ylim(0, 1)
            ax.set_xlabel('x')
            ax.set_ylabel('t')
            ax.set_zlabel('u')
            ax.plot_surface(test_pts[1], test_pts[0], predictions, cmap='viridis')
            fig.savefig(f'./.temp/frame_{i}.png', dpi=300)
        finally:
            plt.close(fig)

    num_frames = len(files)
    frame_paths = [f'./.temp/frame_{i}.png' for i in range(num_frames)]

    # Render all frames in parallel.
    Parallel(n_jobs=processors, verbose=4)(delayed(plot)(i) for i in range(num_frames))

    if type == 'mp4':
        # Context manager closes the writer even if appending a frame fails.
        with imageio.get_writer(path, fps=fps) as writer:
            for frame_path in frame_paths:
                writer.append_data(imageio.imread(frame_path))
    else:
        imgs = [Image.open(frame_path) for frame_path in frame_paths]
        try:
            imgs[0].save(path, save_all=True, append_images=imgs[1:], duration=duration, fps=fps, loop=loop)
        finally:
            for img in imgs:
                img.close()
    
# Snapshots written by `make_plot`: every `save_every`-th iteration up to and
# including the final one. The stop is `trainer.iter + 1` because `range`
# excludes its stop value; multiples with no saved snapshot are skipped.
snapshot_paths = (f'./.temp/{i}.npy' for i in range(0, trainer.iter + 1, save_every))
files = [p for p in snapshot_paths if os.path.exists(p)]

os.makedirs('./.results', exist_ok=True)
save_animation(files, './.results/wave animation.gif', type='gif', processors=8)