In [None]:
import glob
import os
import re

import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
import scienceplots
import torch
from mpl_toolkits.mplot3d import Axes3D
from rkan.torch import JacobiRKAN, PadeRKAN
from torch import nn, optim

# Publication styling: "science" is a style sheet presumably registered by the
# `scienceplots` import above — confirm if that import is ever removed.
plt.style.use("science")
# Switch to the PGF backend so saved figures are typeset through LaTeX.
# NOTE(review): with this backend figures are not rendered inline in the notebook.
mpl.use("pgf")

# Typeset all text with LaTeX and load the AMS packages the labels rely on
# (e.g. \mathfrak in the saved error figure).
plt.rcParams["text.usetex"] = True
plt.rcParams["pgf.preamble"] = r"\usepackage{amssymb} \usepackage{amsmath}"

In [30]:
def dy_dx(y, x):
    """Differentiate ``y`` with respect to ``x`` via autograd.

    ``grad_outputs`` is a tensor of ones, so for a non-scalar ``y`` this is the
    gradient of ``y.sum()`` w.r.t. ``x`` (the elementwise derivative when each
    entry of ``y`` depends only on the matching entry of ``x``).
    ``create_graph=True`` keeps the result differentiable, which is what allows
    the second derivatives ``u_xx`` / ``u_yy`` to be taken from it.

    Args:
        y: tensor to differentiate (must be connected to ``x`` in the graph).
        x: tensor to differentiate with respect to.

    Returns:
        A tensor shaped like ``x``.
    """
    (derivative,) = torch.autograd.grad(
        outputs=y,
        inputs=x,
        grad_outputs=torch.ones_like(y),
        create_graph=True,
    )
    return derivative

In [38]:
def closure():
    """Closure passed to ``optimizer.step`` (L-BFGS re-evaluates it internally).

    Recomputes the loss on the module-level collocation points ``x``/``y``,
    clears the gradients held by ``optimizer``, backpropagates, and returns the
    loss tensor. Relies on the globals ``get_loss``, ``optimizer``, ``x`` and
    ``y`` defined in other cells.
    """
    objective = get_loss(x, y)
    optimizer.zero_grad()
    objective.backward()
    return objective

In [31]:
# Collocation grid: n_points x n_points samples of [0, 1]^2, flattened into
# (n_points**2, 1) column tensors. requires_grad=True on the 1-D axes keeps the
# derived x / y columns in the autograd graph so get_loss can differentiate
# the network output with respect to them.
n_points = 50
x_axis, y_axis = (
    torch.linspace(0, 1, n_points, requires_grad=True) for _ in range(2)
)
x, y = (
    grid.reshape(-1, 1)
    for grid in torch.meshgrid(x_axis, y_axis, indexing="ij")
)

In [34]:
from torch import nn

# Fix the RNG right before building the network so its initial weights — and
# therefore the whole training run — are reproducible on Restart & Run All.
torch.manual_seed(0)

# 2 -> 10 -> 10 -> 1 network; each hidden Linear layer is followed by a
# JacobiRKAN activation. NOTE(review): the argument 4 is presumably the Jacobi
# polynomial degree — confirm against the rkan documentation.
mlp = nn.Sequential(
    nn.Linear(2, 10),
    JacobiRKAN(4),
    nn.Linear(10, 10),
    JacobiRKAN(4),
    nn.Linear(10, 1),
)

In [35]:
# L-BFGS over every trainable parameter of the network with step size 0.1;
# each optimizer.step(closure) call runs several inner iterations.
trainable_params = list(mlp.parameters())
optimizer = optim.LBFGS(trainable_params, lr=0.1)

In [36]:
def Exact(x, y):
    """Closed-form solution of the Poisson problem trained in this notebook.

    u(x, y) = -sin(pi x) sin(pi y) / (2 pi^2) satisfies
    u_xx + u_yy = sin(pi x) sin(pi y) (the PDE residual used in get_loss) and
    vanishes on the edges of the unit square.

    Written as a ``def`` instead of a lambda bound to a name (PEP 8 E731) so it
    has a real ``__name__`` and docstring.

    Args:
        x, y: tensors (or Python/NumPy scalars, converted with
            ``torch.as_tensor``) that broadcast against each other. Tensors
            are passed through unchanged, so autograd connectivity is kept.

    Returns:
        Tensor with the broadcast shape of ``x`` and ``y``.
    """
    x = torch.as_tensor(x)
    y = torch.as_tensor(y)
    return -1 / (2 * np.pi**2) * torch.sin(np.pi * x) * torch.sin(np.pi * y)

In [37]:
def get_loss(x, y, scale=1e6):
    """Scaled physics-informed loss for u_xx + u_yy = sin(pi x) sin(pi y) on [0, 1]^2.

    The module-level network ``mlp`` is evaluated at the interior points
    (x, y) and at the projections of those points onto the four edges of the
    unit square. The loss is the mean squared PDE residual plus the four mean
    squared boundary mismatches against ``Exact``, all multiplied by ``scale``.

    Args:
        x, y: column tensors of matching shape. They must be part of an
            autograd graph so derivatives of the network output w.r.t. them can
            be taken — assumes (N, 1) tensors as built in the grid cell.
        scale: multiplier on the total loss. The default 1e6 reproduces the
            original weighting (it keeps L-BFGS away from tiny loss values).

    Returns:
        Scalar loss tensor attached to the autograd graph.
    """
    interior = torch.cat((x, y), 1)
    bottom = torch.cat((x, torch.zeros_like(y)), 1)  # edge y = 0
    top = torch.cat((x, torch.ones_like(y)), 1)  # edge y = 1
    left = torch.cat((torch.zeros_like(x), y), 1)  # edge x = 0
    right = torch.cat((torch.ones_like(x), y), 1)  # edge x = 1

    # Call the module (not .forward) so any registered hooks run.
    u = mlp(interior)
    u_x = dy_dx(u, x)
    u_y = dy_dx(u, y)
    u_xx = dy_dx(u_x, x)
    u_yy = dy_dx(u_y, y)

    residual_pde = u_xx + u_yy - torch.sin(np.pi * x) * torch.sin(np.pi * y)

    residuals = (
        residual_pde,
        mlp(bottom) - Exact(x, torch.tensor(0)),
        mlp(top) - Exact(x, torch.tensor(1)),
        mlp(left) - Exact(torch.tensor(0), y),
        mlp(right) - Exact(torch.tensor(1), y),
    )
    return scale * sum((residual**2).mean() for residual in residuals)


get_loss(x, y)

tensor(1409404., grad_fn=<MulBackward0>)

In [39]:
# Train with L-BFGS. Each "epoch" below is one optimizer.step call, which runs
# several inner L-BFGS iterations through `closure`.
print_every = 1  # print progress every this many steps

losses = []

for i in range(1, 50):
    # LBFGS.step returns the loss from the closure's first evaluation, i.e. the
    # loss at the parameters *before* this step. That is the same value an extra
    # get_loss(x, y) call would give, without a second forward/double-backward pass.
    loss = optimizer.step(closure)
    losses.append(loss.item())
    if i % print_every == 0:
        print("Epoch %3d: Current loss: %.2e" % (i, losses[-1]))

Epoch   1: Current loss: 1.41e+06
Epoch   2: Current loss: 7.94e+04
Epoch   3: Current loss: 4.37e+03
Epoch   4: Current loss: 7.12e+02
Epoch   5: Current loss: 3.16e+02
Epoch   6: Current loss: 1.84e+02
Epoch   7: Current loss: 9.23e+01
Epoch   8: Current loss: 5.99e+01
Epoch   9: Current loss: 4.48e+01
Epoch  10: Current loss: 2.40e+01
Epoch  11: Current loss: 1.97e+01
Epoch  12: Current loss: 1.67e+01
Epoch  13: Current loss: 1.45e+01
Epoch  14: Current loss: 1.33e+01
Epoch  15: Current loss: 1.30e+01
Epoch  16: Current loss: 1.27e+01
Epoch  17: Current loss: 1.20e+01
Epoch  18: Current loss: 1.14e+01
Epoch  19: Current loss: 1.07e+01
Epoch  20: Current loss: 1.04e+01
Epoch  21: Current loss: 1.02e+01
Epoch  22: Current loss: 9.75e+00
Epoch  23: Current loss: 9.23e+00
Epoch  24: Current loss: 8.78e+00
Epoch  25: Current loss: 8.27e+00
Epoch  26: Current loss: 7.98e+00
Epoch  27: Current loss: 7.85e+00
Epoch  28: Current loss: 7.89e+00
Epoch  29: Current loss: 7.71e+00
Epoch  30: Cur

In [43]:
# Evaluate on a fresh Nx x Nt grid of [0, 1]^2.
# NOTE: `t` is the second *spatial* coordinate (the problem has no time). This
# cell rebinds the global `x` used by get_loss/closure, so re-running the
# training cells after it would train on this grid, which does not require grad.
# The names are kept because the plotting cells below read `x` and `t`.
Nx, Nt = 31, 25
x = torch.linspace(0, 1, Nx)
t = torch.linspace(0, 1, Nt)

x, t = torch.meshgrid(x, t, indexing="ij")
x = x.reshape(-1, 1)
t = t.reshape(-1, 1)
x_t = torch.cat((x, t), 1)

exact = Exact(x, t).reshape(Nx, Nt)
# Inference only: no autograd graph is needed for the prediction.
with torch.no_grad():
    predict = mlp(x_t).reshape(Nx, Nt)
error = exact - predict

MAE = torch.abs(error).mean().item()

print("Mean Absolute Error: %.2e" % MAE)

Mean Absolute Error: 2.59e-04


In [44]:
import matplotlib.pyplot as plt
import numpy as np

# Back to (Nx, Nt) grids for surface plotting.
x = x.reshape(Nx, Nt)
t = t.reshape(Nx, Nt)

# plot_surface wants array-likes; convert the coordinate grids once.
x_np = x.numpy()
t_np = t.numpy()

# Side-by-side surfaces: network prediction, exact solution, pointwise error.
panels = (
    ("Prediction", predict),
    ("Exact", exact),
    ("Error", error),
)
fig = plt.figure(figsize=(15, 5))
for position, (title, values) in enumerate(panels, start=1):
    ax = fig.add_subplot(1, 3, position, projection="3d")
    ax.plot_surface(x_np, t_np, values.detach().numpy(), cmap="viridis")
    ax.set_title(title)
    ax.view_init(10, 45)

# Training curve on a log10 scale (one point per optimizer.step call).
fig = plt.figure(figsize=(15, 5))
ax = fig.add_subplot(111)
ax.plot(np.log10(losses), "c", label="Loss")
ax.set_xlabel("LBFGS step")
ax.set_ylabel(r"$\log_{10}$ loss")
ax.legend();

[<matplotlib.lines.Line2D at 0x76644dfe3a50>]

In [50]:
def _save_surface_figure(xx, yy, z, zlabel, filename, figsize, elev, azim=-260):
    """Draw ``z`` over the (xx, yy) grid as a minimal 3D surface and save it.

    Args:
        xx, yy: coordinate grids (tensors or arrays) with the same shape as ``z``.
        z: tensor of surface heights.
        zlabel: LaTeX label for the z-axis.
        filename: output path passed to ``Figure.savefig``.
        figsize: figure size in inches.
        elev, azim: 3D view angles.

    Returns:
        The created figure.
    """
    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot(111, projection="3d")
    ax.plot_surface(np.asarray(xx), np.asarray(yy), z.detach().numpy(), cmap="viridis")

    # Transparent panes with white edges for a clean publication look.
    for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
        axis.pane.fill = False
        axis.pane.set_edgecolor("w")

    ax.set_yticks([0, 0.5, 1])
    ax.set_xlabel(r"$\xi_1$")
    ax.set_ylabel(r"$\xi_2$")
    ax.set_zlabel(zlabel)
    ax.view_init(elev=elev, azim=azim)

    fig.savefig(filename, bbox_inches="tight", pad_inches=0)
    return fig


_save_surface_figure(
    x, t, predict, r"$\hat{F}(\xi_1,\xi_2)$", "elliptic-prediction.pdf",
    figsize=(12, 7), elev=15,
)
# NOTE: despite the file name, `error` is the pointwise error exact - prediction,
# not the PDE residual u_xx + u_yy - f.
_save_surface_figure(
    x, t, error, r"$\mathfrak{R}(\xi_1,\xi_2)$", "elliptic-residual.pdf",
    figsize=(6, 7), elev=25,
);