In [1]:
## all the imports
import torch
import random
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import LBFGS, AdamW
from tqdm import tqdm
from src.util import *
from src.model.setpinns import SetPinns

In [2]:
import warnings
warnings.filterwarnings("ignore")

In [3]:
seed = 0
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.set_warn_always(False)

In [4]:
# xavier init and functions to get the analytical sol
def init_weights(m):
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform(m.weight)
        m.bias.data.fill_(0.01)


def h(x):
    return np.exp(-((x - np.pi) ** 2) / (2 * (np.pi / 4) ** 2))


def u_ana(x, t):
    return h(x) * np.exp(5 * t) / (h(x) * np.exp(5 * t) + 1 - h(x))

In [5]:
# Get data
res_points = 50
test_points = 101
val_points = 21
set_dim=2
device = "cuda:0"

res, b_left, b_right, b_upper, b_lower = get_data(
    [0, 2 * np.pi], [0, 1], res_points, res_points
)
res_test, _, _, _, _ = get_data([0, 2 * np.pi], [0, 1], test_points, test_points)
# val
res_val, _, _, _, _ = get_data([0, 2 * np.pi], [0, 1], val_points, val_points)
u = torch.from_numpy(u_ana(res_val[:, 0], res_val[:, 1])).reshape(-1).to(device)
res_val = torch.tensor(res_val, dtype=torch.float32, requires_grad=True).to(device)
x_val, t_val = res_val[:, 0:1], res_val[:, 1:2]
x_val, t_val = x_val.reshape(-1, 1, 1), t_val.reshape(-1, 1, 1)


res = res.reshape(res_points, res_points, 2)
res = (
    res.reshape(
        res.shape[0] // set_dim, set_dim, res.shape[1] // set_dim, set_dim, 2
    )
    .swapaxes(1, 2)
    .reshape(-1, set_dim, set_dim, 2)
)
res = res.reshape(res.shape[0], -1, 2)

res = torch.tensor(res, dtype=torch.float32, requires_grad=True).to(device)
b_left = torch.tensor(b_left, dtype=torch.float32, requires_grad=True).to(device)
b_right = torch.tensor(b_right, dtype=torch.float32, requires_grad=True).to(device)
b_upper = torch.tensor(b_upper, dtype=torch.float32, requires_grad=True).to(device)
b_lower = torch.tensor(b_lower, dtype=torch.float32, requires_grad=True).to(device)

x_res, t_res = res[:, :, 0], res[:, :, 1]
x_left, t_left = b_left[:, 0], b_left[:, 1]
x_right, t_right = b_right[:, 0], b_right[:, 1]
x_upper, t_upper = b_upper[:, 0], b_upper[:, 1]
x_lower, t_lower = b_lower[:, 0], b_lower[:, 1]
(
    x_res,
    t_res,
    x_left,
    t_left,
    x_right,
    t_right,
    x_upper,
    t_upper,
    x_lower,
    t_lower,
) = (
    x_res.unsqueeze(-1),
    t_res.unsqueeze(-1),
    make_set_sequence(x_left, set_dim),
    make_set_sequence(t_left, set_dim),
    make_set_sequence(x_right, set_dim),
    make_set_sequence(t_right, set_dim),
    make_set_sequence(x_upper, set_dim),
    make_set_sequence(t_upper, set_dim),
    make_set_sequence(x_lower, set_dim),
    make_set_sequence(t_lower, set_dim),
)

based on set size, now we have created sets of residual points as well as boundary and initial points. Note that the points in a set are close to each other. By changing the hyper-parameters, you can vary on how "close" these points can be.

In [6]:
x_left

tensor([[[0.0000],
         [0.1282]],

        [[0.2565],
         [0.3847]],

        [[0.5129],
         [0.6411]],

        [[0.7694],
         [0.8976]],

        [[1.0258],
         [1.1541]],

        [[1.2823],
         [1.4105]],

        [[1.5387],
         [1.6670]],

        [[1.7952],
         [1.9234]],

        [[2.0517],
         [2.1799]],

        [[2.3081],
         [2.4363]],

        [[2.5646],
         [2.6928]],

        [[2.8210],
         [2.9493]],

        [[3.0775],
         [3.2057]],

        [[3.3339],
         [3.4622]],

        [[3.5904],
         [3.7186]],

        [[3.8468],
         [3.9751]],

        [[4.1033],
         [4.2315]],

        [[4.3598],
         [4.4880]],

        [[4.6162],
         [4.7444]],

        [[4.8727],
         [5.0009]],

        [[5.1291],
         [5.2574]],

        [[5.3856],
         [5.5138]],

        [[5.6420],
         [5.7703]],

        [[5.8985],
         [6.0267]],

        [[6.1550],
         [6.2832]]], 

In [7]:
t_lower

tensor([[[0.0000],
         [0.0204]],

        [[0.0408],
         [0.0612]],

        [[0.0816],
         [0.1020]],

        [[0.1224],
         [0.1429]],

        [[0.1633],
         [0.1837]],

        [[0.2041],
         [0.2245]],

        [[0.2449],
         [0.2653]],

        [[0.2857],
         [0.3061]],

        [[0.3265],
         [0.3469]],

        [[0.3673],
         [0.3878]],

        [[0.4082],
         [0.4286]],

        [[0.4490],
         [0.4694]],

        [[0.4898],
         [0.5102]],

        [[0.5306],
         [0.5510]],

        [[0.5714],
         [0.5918]],

        [[0.6122],
         [0.6327]],

        [[0.6531],
         [0.6735]],

        [[0.6939],
         [0.7143]],

        [[0.7347],
         [0.7551]],

        [[0.7755],
         [0.7959]],

        [[0.8163],
         [0.8367]],

        [[0.8571],
         [0.8776]],

        [[0.8980],
         [0.9184]],

        [[0.9388],
         [0.9592]],

        [[0.9796],
         [1.0000]]], 

In [8]:
x_res

tensor([[[0.0000],
         [0.1282],
         [0.0000],
         [0.1282]],

        [[0.2565],
         [0.3847],
         [0.2565],
         [0.3847]],

        [[0.5129],
         [0.6411],
         [0.5129],
         [0.6411]],

        ...,

        [[5.6420],
         [5.7703],
         [5.6420],
         [5.7703]],

        [[5.8985],
         [6.0267],
         [5.8985],
         [6.0267]],

        [[6.1550],
         [6.2832],
         [6.1550],
         [6.2832]]], device='cuda:0', grad_fn=<UnsqueezeBackward0>)

In [9]:
# now initialize the model and few important variables. One can use any variant of the Tranformer model.
# define setpinns
model = SetPinns(d_out=1, d_hidden=512, d_model=32, N=1, heads=2).to(device)
# apply xavier init
model.apply(init_weights)
best_val = np.inf
loss_track = []
val_track = []
best_model_weights = None

In [10]:
# now perform training with AdamW
# adam training
optim = AdamW(model.parameters(), lr=3e-4)
print(f"Adam Training")
for i in range(100):
    model.train()
    pred_res = model(x_res, t_res)
    pred_left = model(x_left, t_left)
    pred_upper = model(x_upper, t_upper)
    pred_lower = model(x_lower, t_lower)
    u_t = torch.autograd.grad(
        pred_res,
        t_res,
        grad_outputs=torch.ones_like(pred_res),
        retain_graph=True,
        create_graph=True,
    )[0]
    loss_res = torch.mean((u_t - 5 * pred_res * (1 - pred_res)) ** 2)
    loss_bc = torch.mean((pred_upper - pred_lower) ** 2)
    loss_ic = torch.mean(
        (
            pred_left[:, 0]
            - torch.exp(
                -((x_left[:, 0] - torch.pi) ** 2) / (2 * (torch.pi / 4) ** 2)
            )
        )
        ** 2
    )

    loss_track.append([loss_res.item(), loss_bc.item(), loss_ic.item()])

    loss = loss_res + loss_bc + loss_ic
    optim.zero_grad()
    loss.backward()
    optim.step()

    # validation
    model.eval()
    pred = model(x_val, t_val)
    pred = pred.reshape(-1)
    r = F.mse_loss(pred, u).item()
    val_track.append(r)
    if r < best_val:
        print(f"Best val reached at {i}")
        best_val = r
        best_model_weights = model.state_dict()

Adam Training
Best val reached at 0
Best val reached at 1
Best val reached at 6
Best val reached at 29
Best val reached at 30
Best val reached at 31
Best val reached at 32
Best val reached at 33
Best val reached at 34
Best val reached at 35
Best val reached at 60
Best val reached at 61
Best val reached at 62
Best val reached at 63
Best val reached at 64
Best val reached at 65
Best val reached at 66


In [11]:
# now tune the model using LBFGS
model.load_state_dict(best_model_weights)
model.to(device)
optim = LBFGS(model.parameters(), line_search_fn="strong_wolfe")
for i in range(1000):
    model.train()

    def closure():
        pred_res = model(x_res, t_res)
        pred_left = model(x_left, t_left)
        pred_upper = model(x_upper, t_upper)
        pred_lower = model(x_lower, t_lower)
        u_t = torch.autograd.grad(
            pred_res,
            t_res,
            grad_outputs=torch.ones_like(pred_res),
            retain_graph=True,
            create_graph=True,
        )[0]
        loss_res = torch.mean((u_t - 5 * pred_res * (1 - pred_res)) ** 2)
        loss_bc = torch.mean((pred_upper - pred_lower) ** 2)
        loss_ic = torch.mean(
            (
                pred_left[:, 0]
                - torch.exp(
                    -((x_left[:, 0] - torch.pi) ** 2) / (2 * (torch.pi / 4) ** 2)
                )
            )
            ** 2
        )

        loss_track.append([loss_res.item(), loss_bc.item(), loss_ic.item()])

        loss = loss_res + loss_bc + loss_ic
        optim.zero_grad()
        loss.backward()
        return loss

    optim.step(closure)

    # validation
    model.eval()
    pred = model(x_val, t_val)
    pred = pred.reshape(-1)
    r = F.mse_loss(pred, u).item()
    val_track.append(r)
    if r < best_val:
        print(f"Best val reached at {i}")
        best_val = r
        best_model_weights = model.state_dict()

Best val reached at 0
Best val reached at 1
Best val reached at 2
Best val reached at 35
Best val reached at 36
Best val reached at 37
Best val reached at 38
Best val reached at 39
Best val reached at 41
Best val reached at 42
Best val reached at 43
Best val reached at 45
Best val reached at 46
Best val reached at 47
Best val reached at 51
Best val reached at 52
Best val reached at 53
Best val reached at 54
Best val reached at 59
Best val reached at 60
Best val reached at 61
Best val reached at 65
Best val reached at 66


In [12]:
model.load_state_dict(best_model_weights)
model.to(device)
model.eval()
res_test = torch.tensor(res_test, dtype=torch.float32, requires_grad=True).to(
        device
    )
x_test, t_test = res_test[:, 0:1], res_test[:, 1:2]
x_test, t_test = x_test.reshape(-1, 1, 1), t_test.reshape(-1, 1, 1)
with torch.no_grad():
    pred = model(x_test, t_test)
    pred = pred.cpu().detach().numpy()
pred = pred.reshape(test_points, test_points)
res_test, _, _, _, _ = get_data([0, 2 * np.pi], [0, 1], test_points, test_points)
u = u_ana(res_test[:, 0], res_test[:, 1]).reshape(test_points, test_points)

rl1 = np.sum(np.abs(u - pred)) / np.sum(np.abs(u))
rl2 = np.sqrt(np.sum((u - pred) ** 2) / np.sum(u**2))

print("rRMSE: {:4f}".format(rl2))

rRMSE: 0.042286
