In [1]:
import torch
import torch.nn as nn
from model import dataset
from torch.autograd import grad
import numpy as np
from model.model import VanillaPDETransformer
from collections import defaultdict
from bcics.boundary_conditions import DirichletBC
from bcics.initial_conditions import IC
from tqdm import tqdm

In [2]:
def relative_l2_error(A, B):
    l2_error = torch.norm(A - B)
    l2_norm_A = torch.norm(A)
    
    # To avoid division by zero, add a small constant (e.g., 1e-8)
    epsilon = 1e-8
    relative_error = l2_error / (l2_norm_A + epsilon)
    
    return relative_error

In [3]:
def gen_testdata():
    data = np.load("./data/Burgers.npz")
    t, x, exact = data["t"], data["x"], data["usol"].T
    xx, tt = np.meshgrid(x, t)
    X = np.vstack((np.ravel(xx), np.ravel(tt))).T
    y = exact.flatten()[:, None]
    return X, y

In [4]:
X, y = gen_testdata()
print(X.shape)
print(y.shape)

(25600, 2)
(25600, 1)


In [5]:
def pde(x, y):
    dy_x = grad(y, x, grad_outputs=torch.ones_like(y), retain_graph=True, create_graph=True)[0][:, 0:1]
    dy_t = grad(y, x, grad_outputs=torch.ones_like(y), retain_graph=True, create_graph=True)[0][:, 1:]
    dy_xx = grad(dy_x, x, grad_outputs=torch.ones_like(y), retain_graph=True, create_graph=True)[0][:, 0:1]
    return dy_t + y * dy_x - 0.01 / np.pi * dy_xx

In [6]:
geom = [(-1, 1), (0, 0.99)]

bc_1 = DirichletBC(geom, boundary_dim=0, boundary_point=-1, time_dim=True, func=lambda x: torch.zeros(x.shape[0]).to("cuda"))
bc_2 = DirichletBC(geom, boundary_dim=0, boundary_point=1, time_dim=True, func=lambda x: torch.zeros(x.shape[0]).to("cuda"))
ic = IC(geom, lambda x: -torch.sin(np.pi * x[:, 0:1]))

In [242]:
config = defaultdict(lambda: None,
                            num_feats=2,
                            pos_dim=2,
                            n_targets=1,
                            n_hidden=64,
                            num_feat_layers=2,
                            num_encoder_layers=1,
                            n_head=4,
                            # pred_len=0,
                            dim_feedforward=64,
                            attention_type='softmax',  # no softmax
                            xavier_init=1e-4,
                            diagonal_weight=1e-2,
                            symmetric_init=False,
                            layer_norm=False,
                            attn_norm=False,
                            batch_norm=True,
                            spacial_residual=True,
                            return_attn_weight=False,
                            seq_len=None,
                            activation='silu',
                            decoder_type='pointwise',
                            # freq_dim=64,
                            num_regressor_layers=2,
                            # fourier_modes=16,
                            spacial_dim=2,
                            spacial_fc=True,
                            dropout=0.05,)

In [243]:
net = VanillaPDETransformer(**config)

In [244]:
pinn_dataset = dataset.pinn_collect_dataset(num_collect=2048, geom=geom, time_dim=True,
                 space_distribution='random', time_distribution='uniform', given_data=False)

In [245]:
collect_data = pinn_dataset.prepare_collection_data()

In [246]:
boundary_data_1 = np.random.uniform(0, 1, (80, 1))
start, end= geom[1]
boundary_data_1[:, 0] = boundary_data_1[:, 0] * (end - start) + start
new_col = np.full((boundary_data_1.shape[0], 1), -1.)
boundary_data_1 = np.hstack((new_col, boundary_data_1))
boundary_data_1 = torch.tensor(boundary_data_1, dtype=torch.float32)

In [247]:
boundary_data_2 = np.random.uniform(0, 1, (80, 1))
start, end= geom[1]
boundary_data_2[:, 0] = boundary_data_2[:, 0] * (end - start) + start
new_col = np.full((boundary_data_2.shape[0], 1), 1.)
boundary_data_2 = np.hstack((new_col, boundary_data_2))
boundary_data_2 = torch.tensor(boundary_data_2, dtype=torch.float32)

In [248]:
initial_data = np.random.uniform(0, 1, (160, 1))
start, end= geom[0]
initial_data[:, 0] = initial_data[:, 0] * (end - start) + start
new_col = np.full((initial_data.shape[0], 1), 0.)
initial_data = np.hstack((initial_data, new_col))
initial_data = torch.tensor(initial_data, dtype=torch.float32)

In [249]:
data_list = [collect_data,boundary_data_1, boundary_data_2, initial_data]
all_data = torch.cat(data_list, dim=0)

In [250]:
sizes = [tensor.size(0) for tensor in data_list]
begin_indices = [sum(sizes[:i]) for i in range(len(sizes))]
end_indices = [sum(sizes[:i+1]) for i in range(len(sizes))]

In [251]:
begin_indices

[0, 2048, 2128, 2208]

In [252]:
net.to("cuda:0")

VanillaPDETransformer(
  (dpo): Dropout(p=0.05, inplace=False)
  (feat_extract): Sequential(
    (0): Linear(in_features=2, out_features=64, bias=True)
    (1): SiLU()
    (2): Linear(in_features=64, out_features=64, bias=True)
    (3): SiLU()
    (4): Linear(in_features=64, out_features=64, bias=True)
    (5): SiLU()
  )
  (encoder_layers): ModuleList(
    (0): SimpleTransformerEncoderLayer(
      (attn): SimpleAttention(
        (linears): ModuleList(
          (0-2): 3 x Linear(in_features=64, out_features=64, bias=True)
        )
        (norm_K): ModuleList(
          (0-3): 4 x LayerNorm((16,), eps=1e-05, elementwise_affine=True)
        )
        (norm_Q): ModuleList(
          (0-3): 4 x LayerNorm((16,), eps=1e-05, elementwise_affine=True)
        )
        (fc): Linear(in_features=72, out_features=64, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (ff): FeedForward(
        (lr1): Linear(in_features=64, out_features=64, bias=True)
        (activation

In [253]:
max_iter_adam = 10000
lr = 1e-3
optimizer = torch.optim.Adam(net.parameters())
MSE = nn.MSELoss()

In [254]:
bc_error_1 = bc_1.error(input_data[begin_indices[1]:end_indices[1]],output[begin_indices[1]:end_indices[1]])
loss_bc_1 = MSE(bc_error_1, torch.zeros_like(bc_error_1))

In [255]:
bc_error_2 = bc_2.error(input_data[begin_indices[2]:end_indices[2]],output[begin_indices[2]:end_indices[2]])
loss_bc_2 = MSE(bc_error_2, torch.zeros_like(bc_error_2))

In [256]:
ic_error = ic.error(input_data[begin_indices[3]:end_indices[3]],output[begin_indices[3]:end_indices[3]])
loss_ic = MSE(ic_error, torch.zeros_like(ic_error))

In [257]:
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler

class Scheduler(_LRScheduler):
    def __init__(self, 
                 optimizer: Optimizer,
                 dim_embed: int,
                 warmup_steps: int,
                 last_epoch: int=-1,
                 verbose: bool=False) -> None:

        self.dim_embed = dim_embed
        self.warmup_steps = warmup_steps
        self.num_param_groups = len(optimizer.param_groups)

        super().__init__(optimizer, last_epoch, verbose)
        
    def get_lr(self) -> float:
        lr = calc_lr(self._step_count, self.dim_embed, self.warmup_steps)
        return [lr] * self.num_param_groups


def calc_lr(step, dim_embed, warmup_steps):
    return dim_embed**(-0.5) * min(step**(-0.5), step * warmup_steps**(-1.5))

In [258]:
scheduler = Scheduler(optimizer, 64, 4000)

In [259]:
for i in range(1, max_iter_adam+1):
    input_data = all_data.to('cuda:0') 
    input_data.requires_grad=True
    output = net(input_data)
    pde_results = pde(input_data, output)
    loss_res = MSE(pde_results, torch.zeros_like(pde_results))
    
    bc_error_1 = bc_1.error(input_data[begin_indices[1]:end_indices[1]],output[begin_indices[1]:end_indices[1]])
    loss_bc_1 = MSE(bc_error_1, torch.zeros_like(bc_error_1))
    
    bc_error_2 = bc_2.error(input_data[begin_indices[2]:end_indices[2]],output[begin_indices[2]:end_indices[2]])
    loss_bc_2 = MSE(bc_error_2, torch.zeros_like(bc_error_2))
    
    ic_error = ic.error(input_data[begin_indices[3]:end_indices[3]],output[begin_indices[3]:end_indices[3]])
    loss_ic = MSE(ic_error, torch.zeros_like(ic_error))
    
    loss= loss_res+loss_bc_1+loss_bc_2+loss_ic
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()
    if i%100 ==0:
        print('Epoch:%d \t Loss:%.6f \t Loss_pde:%.6f \t Loss_bc_1:%.6f \t Loss_bc_2:%.6f \t Loss_ic:%.6f'%(i,loss.detach().cpu().numpy(), loss_res.detach().cpu().numpy(), loss_bc_1.detach().cpu().numpy(), loss_bc_2.detach().cpu().numpy(), loss_ic.detach().cpu().numpy()))

Epoch:100 	 Loss:0.499068 	 Loss_pde:0.000296 	 Loss_bc_1:0.007553 	 Loss_bc_2:0.013723 	 Loss_ic:0.477495
Epoch:200 	 Loss:0.466688 	 Loss_pde:0.014335 	 Loss_bc_1:0.021153 	 Loss_bc_2:0.037860 	 Loss_ic:0.393341
Epoch:300 	 Loss:0.272905 	 Loss_pde:0.057888 	 Loss_bc_1:0.025469 	 Loss_bc_2:0.031850 	 Loss_ic:0.157698
Epoch:400 	 Loss:0.088510 	 Loss_pde:0.031389 	 Loss_bc_1:0.006226 	 Loss_bc_2:0.009535 	 Loss_ic:0.041360
Epoch:500 	 Loss:0.049348 	 Loss_pde:0.018090 	 Loss_bc_1:0.004248 	 Loss_bc_2:0.005654 	 Loss_ic:0.021357
Epoch:600 	 Loss:0.031040 	 Loss_pde:0.012227 	 Loss_bc_1:0.002239 	 Loss_bc_2:0.003937 	 Loss_ic:0.012637
Epoch:700 	 Loss:0.025364 	 Loss_pde:0.009984 	 Loss_bc_1:0.001569 	 Loss_bc_2:0.003844 	 Loss_ic:0.009968
Epoch:800 	 Loss:0.021101 	 Loss_pde:0.008605 	 Loss_bc_1:0.001073 	 Loss_bc_2:0.001800 	 Loss_ic:0.009622
Epoch:900 	 Loss:0.018566 	 Loss_pde:0.007755 	 Loss_bc_1:0.001186 	 Loss_bc_2:0.001689 	 Loss_ic:0.007935
Epoch:1000 	 Loss:0.014207 	 Loss_pde

In [260]:
net.eval()
test_data = torch.as_tensor(X,dtype=torch.float32).to('cuda:0')
test_label = torch.as_tensor(y,dtype=torch.float32).to('cuda:0')

In [261]:
with torch.no_grad():
    test_pred = net(test_data)

In [262]:
results = relative_l2_error(test_label, test_pred)

In [263]:
results

tensor(0.3188, device='cuda:0')

In [264]:
net.train()

VanillaPDETransformer(
  (dpo): Dropout(p=0.05, inplace=False)
  (feat_extract): Sequential(
    (0): Linear(in_features=2, out_features=64, bias=True)
    (1): SiLU()
    (2): Linear(in_features=64, out_features=64, bias=True)
    (3): SiLU()
    (4): Linear(in_features=64, out_features=64, bias=True)
    (5): SiLU()
  )
  (encoder_layers): ModuleList(
    (0): SimpleTransformerEncoderLayer(
      (attn): SimpleAttention(
        (linears): ModuleList(
          (0-2): 3 x Linear(in_features=64, out_features=64, bias=True)
        )
        (norm_K): ModuleList(
          (0-3): 4 x LayerNorm((16,), eps=1e-05, elementwise_affine=True)
        )
        (norm_Q): ModuleList(
          (0-3): 4 x LayerNorm((16,), eps=1e-05, elementwise_affine=True)
        )
        (fc): Linear(in_features=72, out_features=64, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (ff): FeedForward(
        (lr1): Linear(in_features=64, out_features=64, bias=True)
        (activation

In [265]:
optimizer_lbfgs = torch.optim.LBFGS(
            net.parameters(), 
            lr=1.0, 
            max_iter=10000, 
            max_eval=None, 
            history_size=100,
            tolerance_grad=1e-5, 
            tolerance_change=1.0 * np.finfo(float).eps,
            line_search_fn="strong_wolfe"       # can be "strong_wolfe"
        )

In [266]:
def closure():
    optimizer_lbfgs.zero_grad()
    input_data = all_data.to('cuda:0') 
    input_data.requires_grad=True
    output = net(input_data)
    pde_results = pde(input_data, output)
    loss_res = MSE(pde_results, torch.zeros_like(pde_results))
    
    bc_error_1 = bc_1.error(input_data[begin_indices[1]:end_indices[1]],output[begin_indices[1]:end_indices[1]])
    loss_bc_1 = MSE(bc_error_1, torch.zeros_like(bc_error_1))
    
    bc_error_2 = bc_2.error(input_data[begin_indices[2]:end_indices[2]],output[begin_indices[2]:end_indices[2]])
    loss_bc_2 = MSE(bc_error_2, torch.zeros_like(bc_error_2))
    
    ic_error = ic.error(input_data[begin_indices[3]:end_indices[3]],output[begin_indices[3]:end_indices[3]])
    loss_ic = MSE(ic_error, torch.zeros_like(ic_error))
    
    loss= loss_res+loss_bc_1+loss_bc_2+loss_ic
    if i%100==0:
        print('Epoch:%d \t Loss:%.6f \t Loss_pde:%.6f \t Loss_bc_1:%.6f \t Loss_bc_2:%.6f \t Loss_ic:%.6f'%(i,loss.detach().cpu().numpy(), loss_res.detach().cpu().numpy(), loss_bc_1.detach().cpu().numpy(), loss_bc_2.detach().cpu().numpy(), loss_ic.detach().cpu().numpy()))
    loss.backward()
    return loss

In [267]:
for i in range(1, 10000+1):
    optimizer_lbfgs.step(closure)

Epoch:100 	 Loss:0.003811 	 Loss_pde:0.001277 	 Loss_bc_1:0.000066 	 Loss_bc_2:0.000027 	 Loss_ic:0.002441
Epoch:100 	 Loss:0.004065 	 Loss_pde:0.001306 	 Loss_bc_1:0.000084 	 Loss_bc_2:0.000030 	 Loss_ic:0.002646
Epoch:100 	 Loss:0.002793 	 Loss_pde:0.001109 	 Loss_bc_1:0.000054 	 Loss_bc_2:0.000028 	 Loss_ic:0.001601
Epoch:100 	 Loss:0.003278 	 Loss_pde:0.001342 	 Loss_bc_1:0.000079 	 Loss_bc_2:0.000023 	 Loss_ic:0.001835
Epoch:100 	 Loss:0.003479 	 Loss_pde:0.001349 	 Loss_bc_1:0.000054 	 Loss_bc_2:0.000027 	 Loss_ic:0.002049
Epoch:200 	 Loss:0.003803 	 Loss_pde:0.001214 	 Loss_bc_1:0.000052 	 Loss_bc_2:0.000028 	 Loss_ic:0.002509
Epoch:200 	 Loss:0.003384 	 Loss_pde:0.001366 	 Loss_bc_1:0.000062 	 Loss_bc_2:0.000032 	 Loss_ic:0.001924
Epoch:200 	 Loss:0.003388 	 Loss_pde:0.001269 	 Loss_bc_1:0.000065 	 Loss_bc_2:0.000030 	 Loss_ic:0.002024
Epoch:200 	 Loss:0.003508 	 Loss_pde:0.001275 	 Loss_bc_1:0.000062 	 Loss_bc_2:0.000026 	 Loss_ic:0.002145
Epoch:200 	 Loss:0.003531 	 Loss_pde:

KeyboardInterrupt: 