# Import Packages

In [14]:
# Importing necessary libraries
import torch
import gc
import torch.nn as nn
import torch.nn.functional as F

import math
import numpy as np
import seaborn as sns
import scienceplots
import matplotlib.pyplot as plt

import traceback
from collections import OrderedDict
from tqdm.auto import tqdm

# Set the random seed for reproducibility
seed = 20230808
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)

# Set the device to GPU if available, otherwise use CPU
device = torch.device(
    "cuda") if torch.cuda.is_available() else torch.device("cpu")

# Perform garbage collection and empty the GPU cache in PyTorch
gc.collect()
torch.cuda.empty_cache()


# Model

In [15]:
class NN(nn.Module):
    """
    A simple neural network class.

    Args:
        input_size (int): The number of input features.
        hidden_size (int): The number of hidden units in each hidden layer.
        output_size (int): The number of output features.
        depth (int): The number of hidden layers.
        ac (torch.nn.Module): The activation function to use for each hidden layer.
    """

    def __init__(
        self,
        input_size,
        hidden_size,
        output_size,
        depth,
        ac=torch.nn.Tanh,
    ):
        super(NN, self).__init__()

        layers = [('input', torch.nn.Linear(input_size, hidden_size))]
        layers.append(('input_activation', ac()))
        for i in range(depth):
            layers.append(
                ('hidden_%d' % i, torch.nn.Linear(hidden_size, hidden_size))
            )
            layers.append(('activation_%d' % i, ac()))
        layers.append(('output', torch.nn.Linear(hidden_size, output_size)))

        layerDict = OrderedDict(layers)
        self.layers = torch.nn.Sequential(layerDict)

    def forward(self, x):
        """
        Forward pass through the network.

        Args:
            x (torch.Tensor): The input tensor of shape (batch_size, input_size).

        Returns:
            torch.Tensor: The output tensor of shape (batch_size, output_size).
        """
        return self.layers(x)


# Training Loop

A one-dimensional wave equation is chosen for our experiments, which, in mathematical form, is defined as follows:

$$ u_{tt} - u_{xx} = 0$$

for this wave equation, its initial conditions and the homogeneous Dirichlet boundary conditions are given, as follows:

$$ u(0, x) = \frac{1}{2} \sin (\pi x)$$
$$ u_t(0, x) = \pi \sin (3 \pi x)$$
$$ u(t, 0) =  u(t, 1) = 0 $$

In [16]:
def init_weights(layer):
    """
    Initializes the weights of a layer with Xavier normal initialization.
    Args:
        layer (torch.nn.Module): The layer to initialize.
    """
    if isinstance(layer, nn.Linear):
        torch.nn.init.xavier_normal_(layer.weight)


class Net:
    def __init__(self):
        self.model = NN(
            input_size=2,
            hidden_size=100,
            output_size=1,
            depth=6,
            ac=torch.nn.Tanh
        ).to(device)

        # use the Glorot normal initializer for initialization
        self.model.apply(init_weights)

        # The initial conditions, boundary conditions
        # with Nu approximating 300
        self.h = 0.1
        self.k = 0.1
        x = torch.arange(0, 1 + self.h, self.h)
        t = torch.arange(0, 1 + self.k, self.k)
        # x[0] = 0, x[-1] = 1, t[0] = 0
        bc1 = torch.stack(torch.meshgrid(x[0], t)).reshape(2, -1).T
        bc2 = torch.stack(torch.meshgrid(x[-1], t)).reshape(2, -1).T
        ic = torch.stack(torch.meshgrid(x, t[0])).reshape(2, -1).T
        self.X_train = torch.cat([bc1, bc2, ic])
        y_bc1 = torch.zeros(len(bc1))
        y_bc2 = torch.zeros(len(bc2))
        y_ic = 1/2*torch.sin(math.pi * ic[:, 0])
        self.y_train = torch.cat([y_bc1, y_bc2, y_ic])
        self.y_train = self.y_train.unsqueeze(1)

        # Data in the space-time domain
        # with Nf approximating 40000
        self.h = 0.005
        self.k = 0.005
        x = torch.arange(self.h, 1 + self.h, self.h)
        t = torch.arange(self.k, 1 + self.k, self.k)
        self.X = torch.stack(torch.meshgrid(x, t)).reshape(2, -1).T

        # Device Consistency
        self.X = self.X.to(device)
        self.X_train = self.X_train.to(device)
        self.y_train = self.y_train.to(device)
        self.X.requires_grad = True

        # Logger
        self.iter = 1

        # Loss Function
        self.criterion = torch.nn.MSELoss()

        # Two Optimizer
        self.optimizer = torch.optim.LBFGS(
            self.model.parameters(),
            lr=0.001,
            max_iter=50000, 
            max_eval=50000, 
            history_size=50,
            tolerance_grad=1e-7, 
            tolerance_change=1.0 * np.finfo(float).eps,
            line_search_fn="strong_wolfe")

        self.adam = torch.optim.Adam(self.model.parameters())

    def loss_func(self):
        self.adam.zero_grad()
        self.optimizer.zero_grad()

        # loss using observations of initial and boundary conditions
        y_pred = self.model(self.X_train)
        loss_data = self.criterion(y_pred, self.y_train)

        # loss based on partial differential equations
        u = self.model(self.X)

        du_dX = torch.autograd.grad(
            inputs=self.X,
            outputs=u,
            grad_outputs=torch.ones_like(u),
            retain_graph=True,
            create_graph=True
        )[0]

        du_dt = du_dX[:, 1]
        du_dx = du_dX[:, 0]

        du_dxx = torch.autograd.grad(
            inputs=self.X,
            outputs=du_dX,
            grad_outputs=torch.ones_like(du_dX),
            retain_graph=True,
            create_graph=True
        )[0][:, 0]

        du_dtt = torch.autograd.grad(
            inputs=self.X,
            outputs=du_dt,
            grad_outputs=torch.ones_like(du_dt),
            retain_graph=True,
            create_graph=True
        )[0][:, 1]

        loss_pde = self.criterion(du_dtt - du_dxx, du_dxx*0)

        loss = loss_pde + loss_data
        loss.backward()
        if self.iter % 100 == 0:
            print(self.iter, loss.item())
        self.iter = self.iter + 1
        return loss

    def train(self):
        self.model.train()
        # L-BFGS 30,000 epochs
        for _ in tqdm(range(20)):
            self.optimizer.step(self.loss_func)
            # continued the optimization using Adam
            for _ in range(1500):
                self.adam.step(self.loss_func)

    def eval_(self):
        self.model.eval()


# Training

In [4]:
net = Net()
net.train()


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


  0%|          | 0/20 [00:00<?, ?it/s]

100 0.02055489830672741
200 0.019409509375691414
300 0.018849054351449013
400 0.011236368678510189
500 0.005986572243273258
600 0.004425127990543842
700 0.002523630391806364
800 0.001528450520709157
900 0.0013338442659005523


  0%|          | 0/1500 [00:00<?, ?it/s]

1000 0.2565414309501648
1100 0.01664217747747898
1200 0.00883091613650322
1300 0.0019051822600886226
1400 0.001368343597277999
1500 0.0012860839487984776
1600 0.001252887537702918
1700 0.001237573567777872
1800 0.0012410100316628814
1900 0.007598113268613815
2000 0.0012101908214390278
2100 0.001211515162140131


KeyboardInterrupt: 

In [5]:
torch.save(net.model.state_dict(), 'model.ckpt')


# Evaluation

In [17]:
import matplotlib
import seaborn as sns
import matplotlib.pyplot as plt
import scienceplots
import plotly.express as px

%matplotlib inline
%config InlineBackend.figure_format = 'svg'

# plt.style.use(['ipynb', 'use_mathtext', 'colors5-light', 'science'])
# sns.set_style('whitegrid')
# sns.set_palette('RdBu')
# sns.set(
#     rc={'text.usetex': True},
#     font='serif',
#     font_scale=1.2
# )

# matplotlib.rcParams['font.sans-serif'] = ['SimHei']
# matplotlib.rcParams['font.serif'] = ['SimHei']
# sns.set_style('darkgrid', {'font.sans-serif': ['simhei', 'Arial']})


In [18]:
h = 0.001
k = 0.001
x = torch.arange(0, 1, h)
t = torch.arange(0, 1, k)
X = torch.stack(torch.meshgrid(x, t)).reshape(2, -1).T
X = X.to(device)


In [19]:
X.shape


torch.Size([1000000, 2])

In [20]:
def ground_truth(X):
    x = X[:, 0]
    t = X[:, 1]
    y = 1/2*torch.sin(torch.pi*x)*torch.cos(torch.pi*t)+1 / \
        3*torch.sin(3*torch.pi*x)*torch.sin(3*torch.pi*t)
    return y


In [23]:
model = NN(input_size=2,
           hidden_size=100,
           output_size=1,
           depth=6,
           ac=torch.nn.Tanh).to(device)
model.load_state_dict(torch.load('model.ckpt'))
model.eval()
with torch.no_grad():
    # y_pred = ground_truth(X).reshape(len(x), len(t)).detach().cpu()
    y_pred = model(X).reshape(len(x), len(t)).detach().cpu()


In [None]:
plt.figure(figsize=(5, 3), dpi=150)
sns.heatmap(y_pred, cmap='jet')
plt.show()


In [26]:
y_pred.shape

torch.Size([1000, 1000])

In [None]:
x = torch.arange(0, 1, h).unsqueeze(dim=1).to(device)
t = (torch.ones(x.shape)*0.5).to(device)

plt.scatter(x.detach().cpu(), model(torch.cat([x, t], dim=1)).detach().cpu())