In [1]:
## all the imports
import torch
import random
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import LBFGS, AdamW
from tqdm import tqdm
from src.util import *
from src.model.pinnsformer import PINNsformer

In [2]:
import warnings
warnings.filterwarnings("ignore")

In [3]:
seed = 0
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.set_warn_always(False)

In [4]:
# xavier init and functions to get the analytical sol
def init_weights(m):
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform(m.weight)
        m.bias.data.fill_(0.01)


def h(x):
    return np.exp(-((x - np.pi) ** 2) / (2 * (np.pi / 4) ** 2))


def u_ana(x, t):
    return h(x) * np.exp(5 * t) / (h(x) * np.exp(5 * t) + 1 - h(x))

In [5]:
# Get data
res_points = 50
test_points = 101
val_points = 21
num_step=5
step=1e-4
device = "cuda:0"

res, b_left, b_right, b_upper, b_lower = get_data(
    [0, 2 * np.pi], [0, 1], res_points, res_points
)
res_test, _, _, _, _ = get_data([0, 2 * np.pi], [0, 1], test_points, test_points)
# val
res_val, _, _, _, _ = get_data([0, 2 * np.pi], [0, 1], val_points, val_points)
u = torch.from_numpy(u_ana(res_val[:, 0], res_val[:, 1])).reshape(-1).to(device)
res_val = make_time_sequence(res_val, num_step=num_step, step=step)
res_val = torch.tensor(res_val, dtype=torch.float32, requires_grad=True).to(device)
x_val, t_val = res_val[:, :, 0:1], res_val[:, :, 1:2]

res = make_time_sequence(res, num_step=num_step, step=step)
b_left = make_time_sequence(b_left, num_step=num_step, step=step)
b_right = make_time_sequence(b_right, num_step=num_step, step=step)
b_upper = make_time_sequence(b_upper, num_step=num_step, step=step)
b_lower = make_time_sequence(b_lower, num_step=num_step, step=step)

res = torch.tensor(res, dtype=torch.float32, requires_grad=True).to(device)
b_left = torch.tensor(b_left, dtype=torch.float32, requires_grad=True).to(device)
b_right = torch.tensor(b_right, dtype=torch.float32, requires_grad=True).to(device)
b_upper = torch.tensor(b_upper, dtype=torch.float32, requires_grad=True).to(device)
b_lower = torch.tensor(b_lower, dtype=torch.float32, requires_grad=True).to(device)

x_res, t_res = res[:, :, 0:1], res[:, :, 1:2]
x_left, t_left = b_left[:, :, 0:1], b_left[:, :, 1:2]
x_right, t_right = b_right[:, :, 0:1], b_right[:, :, 1:2]
x_upper, t_upper = b_upper[:, :, 0:1], b_upper[:, :, 1:2]
x_lower, t_lower = b_lower[:, :, 0:1], b_lower[:, :, 1:2]

In [6]:
# now initialize the model and few important variables. One can use any variant of the Tranformer model.
# define setpinns
model = PINNsformer(d_out=1, d_hidden=512, d_model=32, N=1, heads=2).to(device)
# apply xavier init
model.apply(init_weights)
best_val = np.inf
loss_track = []
val_track = []
best_model_weights = None

In [7]:
# now perform training with AdamW
# adam training
optim = AdamW(model.parameters(), lr=3e-4)
print(f"Adam Training")
for i in range(100):
    model.train()
    pred_res = model(x_res, t_res)
    pred_left = model(x_left, t_left)
    pred_upper = model(x_upper, t_upper)
    pred_lower = model(x_lower, t_lower)
    u_t = torch.autograd.grad(
        pred_res,
        t_res,
        grad_outputs=torch.ones_like(pred_res),
        retain_graph=True,
        create_graph=True,
    )[0]
    loss_res = torch.mean((u_t - 5 * pred_res * (1 - pred_res)) ** 2)
    loss_bc = torch.mean((pred_upper - pred_lower) ** 2)
    loss_ic = torch.mean(
        (
            pred_left[:, 0]
            - torch.exp(
                -((x_left[:, 0] - torch.pi) ** 2) / (2 * (torch.pi / 4) ** 2)
            )
        )
        ** 2
    )

    loss_track.append([loss_res.item(), loss_bc.item(), loss_ic.item()])

    loss = loss_res + loss_bc + loss_ic
    optim.zero_grad()
    loss.backward()
    optim.step()

    # validation
    model.eval()
    pred = model(x_val, t_val)[:, 0:1]
    pred = pred.reshape(-1)
    r = F.mse_loss(pred, u).item()
    val_track.append(r)
    if r < best_val:
        print(f"Best val reached at {i}")
        best_val = r
        best_model_weights = model.state_dict()

Adam Training
Best val reached at 0
Best val reached at 1
Best val reached at 7
Best val reached at 19


In [8]:
# now tune the model using LBFGS
model.load_state_dict(best_model_weights)
model.to(device)
optim = LBFGS(model.parameters(), line_search_fn="strong_wolfe")
for i in range(1000):
    model.train()

    def closure():
        pred_res = model(x_res, t_res)
        pred_left = model(x_left, t_left)
        pred_upper = model(x_upper, t_upper)
        pred_lower = model(x_lower, t_lower)
        u_t = torch.autograd.grad(
            pred_res,
            t_res,
            grad_outputs=torch.ones_like(pred_res),
            retain_graph=True,
            create_graph=True,
        )[0]
        loss_res = torch.mean((u_t - 5 * pred_res * (1 - pred_res)) ** 2)
        loss_bc = torch.mean((pred_upper - pred_lower) ** 2)
        loss_ic = torch.mean(
            (
                pred_left[:, 0]
                - torch.exp(
                    -((x_left[:, 0] - torch.pi) ** 2) / (2 * (torch.pi / 4) ** 2)
                )
            )
            ** 2
        )

        loss_track.append([loss_res.item(), loss_bc.item(), loss_ic.item()])

        loss = loss_res + loss_bc + loss_ic
        optim.zero_grad()
        loss.backward()
        return loss

    optim.step(closure)

    # validation
    model.eval()
    pred = model(x_val, t_val)[:, 0:1]
    pred = pred.reshape(-1)
    r = F.mse_loss(pred, u).item()
    val_track.append(r)
    if r < best_val:
        print(f"Best val reached at {i}")
        best_val = r
        best_model_weights = model.state_dict()

In [9]:
model.load_state_dict(best_model_weights)
model.to(device)
model.eval()
res_test = make_time_sequence(res_test, num_step=num_step, step=step)
res_test = torch.tensor(res_test, dtype=torch.float32, requires_grad=True).to(
        device
    )
x_test, t_test = res_test[:, :, 0:1], res_test[:, :, 1:2]
with torch.no_grad():
    pred = model(x_test, t_test)[:, 0:1]
    pred = pred.cpu().detach().numpy()
pred = pred.reshape(test_points, test_points)
res_test, _, _, _, _ = get_data([0, 2 * np.pi], [0, 1], test_points, test_points)
u = u_ana(res_test[:, 0], res_test[:, 1]).reshape(test_points, test_points)

rl1 = np.sum(np.abs(u - pred)) / np.sum(np.abs(u))
rl2 = np.sqrt(np.sum((u - pred) ** 2) / np.sum(u**2))

print("rRMSE: {:4f}".format(rl2))

rRMSE: 0.978569
