In [1]:
import time
import os
import torch
import torch.nn as nn
import numpy as np

from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm
torch.set_default_dtype(torch.float64)


In [2]:
class ConventBlock(nn.Module):
    def __init__(self,in_N,out_N):
        super(ConventBlock, self).__init__()
        self.Ls  = None
        self.net =nn.Sequential(nn.Linear(in_N,out_N),nn.Tanh()) 

    def forward(self, x):
        out = self.net(x)
        return out 
    

class ModifiedResBlock(nn.Module):
    def __init__(self,in_N,out_N):
        super(ModifiedResBlock, self).__init__()
        self.Ls  = None
        if in_N != out_N:
            self.Ls  = nn.Linear(in_N,out_N)
        self.net = nn.Sequential(nn.Linear(in_N,out_N),nn.Tanh())
    def forward(self, x):
        out = self.net(x)
        if self.Ls : 
            x  = self.Ls(x)
        out = out + x
        return out 

In [3]:
class Network(torch.nn.Module):
    def __init__(self,in_N,m,H_layer,out_N):
        super(Network,self).__init__()
        layers = []
        layers.append(ModifiedResBlock(in_N,m))
        for i in range(0,H_layer-1):
            layers.append(ModifiedResBlock(m,m))
        layers.append(nn.Linear(m,out_N))
        self.net = nn.Sequential(*layers)
        
    def forward(self,x,y):
        data = torch.cat((x,y),dim=1);
        out  = self.net(data)
        return out
        
def init_weights(m):
    if type(m) == nn.Linear:
        nn.init.xavier_normal_(m.weight.data)
        nn.init.zeros_(m.bias)

In [4]:
omega = torch.tensor([1.0, 3.0])
def ue(x,y,omega):
    return torch.cos(omega[0]*np.pi*x)*torch.cos(omega[1]*np.pi*y)

def ue_xx(x,y,omega):
    return -(omega[0]*np.pi).pow(2) * ue(x,y,omega)

def ue_yy(x,y,omega):
    return -(omega[1]*np.pi).pow(2) * ue(x,y,omega)

def source_function(x,y,omega):
    return ue_xx(x,y,omega) + ue_yy(x,y,omega)


In [5]:
def fetch_interior_points(N=128):
    phi     = torch.rand(N,1) * 2 * np.pi
    rho     = -1 + torch.cos(phi)*torch.sin(4*phi)
    radius  = rho*torch.sqrt(torch.rand(N,1))
    x       = 0.55*radius*torch.cos(phi)
    y       = 0.75*radius*torch.sin(phi)
    return x,y

def fetch_boundary_points(N=32):
    phi     = torch.rand(N,1) * 2 * np.pi
    rho     = -1 + torch.cos(phi)*torch.sin(4*phi)
    radius  = rho
    x       = 0.55*radius*torch.cos(phi)
    y       = 0.75*radius*torch.sin(phi)
    return x,y
    

In [6]:
def physics_loss(model,x,y,omega):
    u        = model(x,y)
    u_x,u_y  = torch.autograd.grad(u.sum(),(x,y),create_graph=True)
    u_xx     = torch.autograd.grad(u_x.sum(),x,create_graph=True)[0]
    u_yy     = torch.autograd.grad(u_y.sum(),y,create_graph=True)[0]
    
    rhs      = source_function(x,y,omega)
    lhs      = u_xx + u_yy
    pde_loss = (rhs-lhs).pow(2)
    return pde_loss

In [7]:
def boundary_loss(model,x,y,omega):
    u       = model(x,y)
    u_      = ue(x,y,omega)
    e       = (u - u_) 
    bc_loss = e.pow(2)
    return bc_loss

In [8]:
model = Network(in_N=2,m=50,H_layer=3,out_N=1)
model.apply(init_weights)
print(model)

Network(
  (net): Sequential(
    (0): ModifiedResBlock(
      (Ls): Linear(in_features=2, out_features=50, bias=True)
      (net): Sequential(
        (0): Linear(in_features=2, out_features=50, bias=True)
        (1): Tanh()
      )
    )
    (1): ModifiedResBlock(
      (net): Sequential(
        (0): Linear(in_features=50, out_features=50, bias=True)
        (1): Tanh()
      )
    )
    (2): ModifiedResBlock(
      (net): Sequential(
        (0): Linear(in_features=50, out_features=50, bias=True)
        (1): Tanh()
      )
    )
    (3): Linear(in_features=50, out_features=1, bias=True)
  )
)


In [9]:
pytorch_total_params = sum(p.numel() for p in model.parameters())
print(pytorch_total_params)

5451


In [10]:
def evaluate(model):
    model.eval()
    x_test,y_test = fetch_interior_points(100000)
    u_star  = ue(x_test,y_test,omega)   
    u_pred  = model(x_test,y_test).detach()
    l2   = np.linalg.norm(u_star- u_pred.detach(), 2)/np.linalg.norm(u_star, 2)
    linf = max(abs(u_star- u_pred.detach().numpy()))
    return l2,linf.item()

In [None]:
epochs          = 25000
disp            = 5000

print_to_consol = True

N_dom           = 512

N_bc            = 512

model.apply(init_weights)

# update the optimizer
optimizer  = Adam(model.parameters(),lr=1e-2)

# update the learning rate scheduler 
scheduler = ReduceLROnPlateau(optimizer,patience=100,factor=0.95)

# initialize penalty parameter
mu           = torch.tensor(1.0)

# maximum penalty value for safeguarding
mu_max      = torch.tensor(1e4)

# l2 norm of constraints |C|_2
eta          = torch.tensor(0.0)

# penalty tolerance
epsilon      = torch.tensor(1e-8)

# generate boundary points 
x_bc,y_bc  = fetch_boundary_points(N_bc)

# lagrange multipliers 
lambda_bc  = torch.zeros_like(x_bc)

for epoch in tqdm(range(1,epochs+1)):
    
    optimizer.zero_grad()
    # generating collocation points
    x_dm,y_dm      = fetch_interior_points(N_dom)
    x_dm           = x_dm.requires_grad_(True)
    y_dm           = y_dm.requires_grad_(True)

    pde_loss       = physics_loss(model,x_dm,y_dm,omega)

    bc_constraints = boundary_loss(model,x_bc,y_bc,omega)
    bc_loss        = (lambda_bc*bc_constraints).sum()
    penalty        = bc_constraints.pow(2).sum()

    loss           =  pde_loss.sum() +  bc_loss + 0.5 * mu * penalty

    loss.backward()
    optimizer.step()
    scheduler.step(loss.item())

    with torch.no_grad():
        if (torch.sqrt(penalty) >= 0.25*eta) and (torch.sqrt(penalty) > epsilon):
            mu      = min(mu*2.0, mu_max)
            lambda_bc  += mu * bc_constraints
        eta = torch.sqrt(penalty)
        
        if epoch % disp == 0 and print_to_consol:
            l2,linf = evaluate(model)
            print(f"relative l2 error :{l2:2.3e}, linf error : {linf :2.3e}")
            print(f'epoch : {epoch+1:5d}, penalty loss: {penalty.item():2.3e}')
            
# checkpointing the model 
torch.save(model.state_dict(),f"poisson.pt")
l2,linf = evaluate(model)
print(f"relative l2 error :{l2:2.3e}, linf error : {linf :2.3e}")

 20%|██        | 5006/25000 [01:26<11:19, 29.41it/s]

relative l2 error :6.004e-03, linf error : 1.280e-02
epoch :  5001, penalty loss: 5.417e-09


 40%|████      | 10011/25000 [02:52<07:53, 31.67it/s]

relative l2 error :2.038e-03, linf error : 5.896e-03
epoch : 10001, penalty loss: 7.068e-09


 60%|██████    | 15008/25000 [04:19<05:28, 30.38it/s]

relative l2 error :1.484e-03, linf error : 5.968e-03
epoch : 15001, penalty loss: 3.942e-09


 68%|██████▊   | 17065/25000 [04:53<02:11, 60.23it/s]

In [None]:
model.eval()
x_test,y_test = fetch_interior_points(200000)
u_star  = ue(x_test,y_test,omega)   
u_pred  = model(x_test,y_test).detach()

In [None]:
L2 = np.linalg.norm(u_star- u_pred.detach(), 2)/np.linalg.norm(u_star, 2)
print('Relative l2 error_u: %2.3e' % (L2))

Linf = max(abs(u_star- u_pred.detach()))
print('Relative linf error_u: %2.3e' % (Linf))