In [6]:
import sys
sys_path = 'D:/Research_CAE/MyTinyCUDANN/tiny-cuda-nn/main'
sys.path.append(sys_path)
from my_tiny_cuda import my_MLP
from tools import random_points_1D,collocation_points_1D
from gradient import grad1, grad2
import torch
import numpy as np
device = 'cuda'

In [7]:
def generate_grid_points(resolution, field_min = 0, field_max = 1):
    x1_list = np.linspace(field_min, field_max, resolution)
    x2_list = np.linspace(field_min, field_max, resolution)
    X1,X2 = np.meshgrid(x1_list,x2_list)
    X_field = torch.tensor(np.concatenate((X1.reshape(-1,1),X2.reshape(-1,1)),
    axis = 1)).float().to(device)
    return X_field

def sample_all_boundary(batch_size_BC,field_min = 0, field_max = 1):
    n00 = torch.tensor([field_min,field_min])
    n01 = torch.tensor([field_min,field_max])
    n10 = torch.tensor([field_max,field_min])
    n11 = torch.tensor([field_max,field_max])
    X_bot = random_points_1D(int(batch_size_BC),n00,n10)
    X_left = random_points_1D(int(batch_size_BC),n00,n01)
    X_right = random_points_1D(int(batch_size_BC),n10,n11)
    X_top = random_points_1D(int(batch_size_BC),n01,n11)
    X_boundaries = torch.cat((X_bot,X_left,X_top,X_right), dim = 0)
    return X_boundaries

class Wave_equation(torch.nn.Module):
    def __init__(self,c = 1.0, L = 1):
        super().__init__()
        self.c = c
        self.c2 = c**2
        self.L = L
    def strong_form(self,x,grad_result):
        d2u_dt2,d2u_dx2 = grad_result["d2u_dx2"],grad_result["d2u_dy2"]
        result = d2u_dt2 - self.c2*d2u_dx2
        return result
    def variational_energy(self,X,u,du_dt,du_dx):
        # result = 1/2*(du_dt**2 +self.k* du_dx**2)
        result = 0.5*(du_dt**2 - self.c2* du_dx**2)
        return result
    def BC_function(self,X):
        # return torch.sin(X[:,0]*torch.pi)*torch.sin(X[:,1]*torch.pi)
        return self.real_solution(X)
    def real_solution(self,X):
        # return torch.sin(X[:,1]*torch.pi/self.L)*torch.cos(self.c*X[:,0]*torch.pi/self.L)
        return torch.sin(X[:,1] + self.c*X[:,0])

In [3]:
model = my_MLP(activation = torch.nn.Tanh(), n_input_dims = 2,
            n_hidden = 4, width = 64,
            spectral_norm = False,dtype = torch.float32).to(device)
equation = Wave_equation()

In [4]:
boundary_sample_points = sample_all_boundary(1000)
f_boundary = equation.BC_function(boundary_sample_points).to(device)
grid_points = generate_grid_points(100)
# optimizer = torch.optim.Adam(model.parameters(), lr=0.005)
optimizer = torch.optim.LBFGS(model.parameters(),lr = 0.5,
                            max_iter = 100,line_search_fn="strong_wolfe")
MSELoss = torch.nn.MSELoss().to(device)

In [5]:
lam = 1.0
n_step = 300
n_step_output = 10
model.train()
diff_info = grad2(model,grid_points.shape[0])
diff_info.to_device(device)
for i in range(n_step):
    def closure():
        global bc_loss, inner_loss
        optimizer.zero_grad()
        bc_loss = MSELoss(model(boundary_sample_points),f_boundary)
        grad_result = diff_info.forward(grid_points)
        PDE_residual = equation.strong_form(grid_points, grad_result)
        inner_loss = (PDE_residual**2).mean()
        loss = inner_loss + lam* bc_loss
        loss.backward()
        return loss
    optimizer.step(closure)
    if i % n_step_output == 0:
        model.eval()
        with torch.no_grad():
            u_error = MSELoss(model(grid_points),equation.real_solution(grid_points))
        print('Iter:',i,'inner_loss:',inner_loss.item(),"\n",
        'bc_loss:',bc_loss.item(),'u_L2:',u_error.item(),)
        model.train()

Iter: 0 inner_loss: 1.7198677596752532e-05 
 bc_loss: 8.61948064994067e-05 u_L2: 8.884025737643242e-05
Iter: 10 inner_loss: 3.3141293442895403e-07 
 bc_loss: 5.4381217751142685e-08 u_L2: 3.1001215461401443e-07
Iter: 20 inner_loss: 3.3141293442895403e-07 
 bc_loss: 5.4381217751142685e-08 u_L2: 3.1001215461401443e-07
Iter: 30 inner_loss: 3.3141293442895403e-07 
 bc_loss: 5.4381217751142685e-08 u_L2: 3.1001215461401443e-07
Iter: 40 inner_loss: 3.3141293442895403e-07 
 bc_loss: 5.4381217751142685e-08 u_L2: 3.1001215461401443e-07
Iter: 50 inner_loss: 3.3141293442895403e-07 
 bc_loss: 5.4381217751142685e-08 u_L2: 3.1001215461401443e-07
Iter: 60 inner_loss: 3.3141293442895403e-07 
 bc_loss: 5.4381217751142685e-08 u_L2: 3.1001215461401443e-07
Iter: 70 inner_loss: 3.3141293442895403e-07 
 bc_loss: 5.4381217751142685e-08 u_L2: 3.1001215461401443e-07
Iter: 80 inner_loss: 3.3141293442895403e-07 
 bc_loss: 5.4381217751142685e-08 u_L2: 3.1001215461401443e-07
Iter: 90 inner_loss: 3.3141293442895403e-

In [8]:
boundary_sample_points = sample_all_boundary(1000)
f_boundary = equation.BC_function(boundary_sample_points).to(device)
grid_points = generate_grid_points(100)
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)
# optimizer = torch.optim.LBFGS(model.parameters(),lr = 0.5,
#                             max_iter = 100,line_search_fn="strong_wolfe")
MSELoss = torch.nn.MSELoss().to(device)

In [9]:
lam = 1000.0
n_step = 3000
n_step_output = 300
n_step_half_lr = 1000
model.train()
diff_info = grad1(model,grid_points.shape[0])
diff_info.to_device(device)
for i in range(n_step):
    def closure():
        global bc_loss, inner_loss
        optimizer.zero_grad()
        bc_loss = MSELoss(model(boundary_sample_points),f_boundary)
        du_dt,du_dx,u = diff_info.forward_2d(grid_points)
        PDE_residual = equation.variational_energy(grid_points,u,du_dt,du_dx)
        inner_loss =  PDE_residual.mean() #(PDE_residual**2).mean()
        loss = inner_loss + lam* bc_loss
        loss.backward()
        return loss
    optimizer.step(closure)
    if i % n_step_output == 0:
        model.eval()
        with torch.no_grad():
            u_error = MSELoss(model(grid_points),equation.real_solution(grid_points))
        print('Iter:',i,'inner_loss:',inner_loss.item(),"\n",
        'bc_loss:',bc_loss.item(),'u_L2:',u_error.item(),)
        model.train()
    
    if i % n_step_half_lr == 0:
        for param_group in optimizer.param_groups:
            param_group['lr'] = param_group['lr']/2
            current_lr = param_group['lr']
        print("lr:{}".format(current_lr))

Iter: 0 inner_loss: -0.002882980275899172 
 bc_loss: 3.4814283935702406e-06 u_L2: 0.568891704082489
lr:0.0025
Iter: 300 inner_loss: -0.0032196189276874065 
 bc_loss: 1.0966818990709726e-05 u_L2: 6.989463145146146e-05
Iter: 600 inner_loss: -0.003130544675514102 
 bc_loss: 8.737453754292801e-06 u_L2: 6.544514326378703e-05
Iter: 900 inner_loss: -0.0031018853187561035 
 bc_loss: 7.602540335938102e-06 u_L2: 5.8499586884863675e-05
lr:0.00125
Iter: 1200 inner_loss: -0.0030913008376955986 
 bc_loss: 7.047851340757916e-06 u_L2: 5.244095154921524e-05
Iter: 1500 inner_loss: -0.003081482369452715 
 bc_loss: 6.623198260058416e-06 u_L2: 4.697464464697987e-05
Iter: 1800 inner_loss: -0.003069997299462557 
 bc_loss: 6.17291789239971e-06 u_L2: 4.089091089554131e-05
lr:0.000625
Iter: 2100 inner_loss: -0.003059983719140291 
 bc_loss: 5.775395948148798e-06 u_L2: 3.5515848139766604e-05
Iter: 2400 inner_loss: -0.0030537517741322517 
 bc_loss: 5.512651569006266e-06 u_L2: 3.200929859303869e-05
Iter: 2700 inner