## Test on a simple case
Consider the following Poisson equation
$$
\begin{cases}
    \Delta u = 1\qquad &(x, y)\in\Omega\\
    u = 0\qquad &(x, y)\in\partial\Omega.
\end{cases}$$
Here $\Omega = \{(x, y) \mid x^2+y^2 < 1\}$ is the open unit disk.

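The exact solution to this problem is $$u = \frac{1}{4}(x^2+y^2-1).$$
Writing the equation as $\Delta u = f$ with $f \equiv 1$, the Deep Ritz method minimizes the penalized energy
$$I(u) = \int_\Omega \left(\tfrac{1}{2}|\nabla u|^2 + f\,u\right)dx + \beta\int_{\partial\Omega} u^2\,ds.$$
The Euler–Lagrange equation of the interior term is exactly $\Delta u = f$. The $\beta$-weighted penalty enforces $u = 0$ on $\partial\Omega$ approximately.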

In [144]:
% matplotlib inline
import torch 
import torch.nn.functional as F
import numpy as np
from math import *
import matplotlib.pyplot as plt
import matplotlib.cm as cm

# Use float32 as the default floating dtype. This replaces the deprecated
# torch.set_default_tensor_type('torch.FloatTensor').
torch.set_default_dtype(torch.float32)

m = 10  # network width; also the input dimension (only the first 2 coordinates are non-zero)
learning_rate = 1  # NOTE(review): lr=1 with plain SGD may be unstable -- tune if the loss diverges
iterations = 8000  # default 10000
print_every_iter = 100
beta = 10  # coefficient of the boundary penalty term in the loss; set to 1000 in section 3.1
n1 = 1000  # number of points sampled inside the unit disk per iteration
n2 = 100   # number of points sampled on the unit circle (border of Omega) per iteration
n3 = 100   # number of points used for evaluating the error

class DeepRitzNet(torch.nn.Module):
    """Deep Ritz network: three residual blocks followed by a linear output layer.

    Each residual block maps ``s -> phi(L2(phi(L1(s)))) + s`` with the
    activation ``phi(t) = relu(t) ** 3``. The output layer is linear (no
    activation), so the network can represent negative values. This matters
    because the exact solution is <= 0 on the unit disk.

    Args:
        m: width of every hidden layer and the input dimension.
    """

    def __init__(self, m):
        super().__init__()
        self.linear1 = torch.nn.Linear(m, m)
        self.linear2 = torch.nn.Linear(m, m)
        self.linear3 = torch.nn.Linear(m, m)
        self.linear4 = torch.nn.Linear(m, m)
        self.linear5 = torch.nn.Linear(m, m)
        self.linear6 = torch.nn.Linear(m, m)

        self.linear7 = torch.nn.Linear(m, 1)

    @staticmethod
    def _residual_block(x, first, second):
        """Apply ``phi(second(phi(first(x)))) + x`` with ``phi = relu(.) ** 3``."""
        h = F.relu(first(x)) ** 3
        # The second layer consumes the first layer's output, not the block input.
        h = F.relu(second(h)) ** 3
        return h + x

    def forward(self, x):
        """Evaluate the network on ``x`` of shape (..., m); returns shape (..., 1)."""
        x = self._residual_block(x, self.linear1, self.linear2)
        x = self._residual_block(x, self.linear3, self.linear4)
        x = self._residual_block(x, self.linear5, self.linear6)
        # Linear read-out: no activation, so outputs may be negative.
        return self.linear7(x)

In [None]:
def draw_graph(step=0.1):
    """Plot the pointwise error of the global ``model`` against the exact solution.

    The error ``model(x, y) - u(x, y)`` is evaluated on a regular grid with
    spacing ``step`` over [-1, 1)^2 and drawn as a heat map. Each grid point
    is zero-padded to length ``m`` (the network input size) before evaluation.
    Grid points outside the closed unit disk, where the PDE is not posed, are
    drawn blank.

    Args:
        step: grid spacing; the default 0.1 reproduces the original 20x20 grid.
    """
    points = np.arange(-1, 1, step)
    xs, ys = np.meshgrid(points, points)

    # Evaluate the whole grid in one batched forward pass instead of one call per point.
    batch = np.zeros((xs.size, m))
    batch[:, 0] = xs.ravel()
    batch[:, 1] = ys.ravel()
    with torch.no_grad():  # plotting only; no autograd graph is needed
        predicted = model(torch.tensor(batch).float()).numpy().reshape(xs.shape)

    # Same closed form as U_groundtruth: u = (x^2 + y^2 - 1) / 4.
    radius_sq = xs ** 2 + ys ** 2
    z = predicted - (radius_sq - 1) / 4
    z[radius_sq > 1] = np.nan  # outside Omega: masked, rendered blank by imshow

    half = step / 2
    plt.imshow(
        z,
        cmap=cm.hot,
        origin="lower",  # row index follows y, so y must increase upward
        extent=(points[0] - half, points[-1] + half,
                points[0] - half, points[-1] + half),
    )
    plt.colorbar()
    plt.xlabel("x")
    plt.ylabel("y")
    plt.show()

In [146]:
def U_groundtruth(t):
    """Return the exact solution u(x, y) = (x^2 + y^2 - 1) / 4 as a Python float.

    ``t`` is an (m,) tensor whose first two entries are the planar
    coordinates (x, y); any further entries are ignored.
    """
    squared_radius = t[0] ** 2 + t[1] ** 2
    return (squared_radius - 1).item() / 4

def zeropad(x_2, m):
    """Embed the 2-D point ``x_2`` into a length-``m`` tensor of zeros.

    Only the first two entries of ``x_2`` (a tensor, ndarray or sequence)
    are copied; every other entry of the result is zero. The result has the
    default tensor dtype.
    """
    padded = torch.zeros(m)
    for k in range(2):
        padded[k] = x_2[k]
    return padded
    
def on_sample(m):
    """Sample one point uniformly on the unit circle, embedded in R^m.

    Returns a float64 tensor of shape (m,) with ``requires_grad=True``. Its
    first two entries are (cos theta, sin theta) for a uniform angle theta;
    the remaining entries are zero.
    """
    theta = np.random.rand() * 2 * np.pi
    re = np.zeros(m)
    # numpy trig: the bare name `math` is not bound by `from math import *`.
    re[0] = np.cos(theta)
    re[1] = np.sin(theta)
    re = torch.tensor(re, requires_grad=True)
    return re

def in_sample(m):
    """Sample one point uniformly (by area) inside the unit disk, embedded in R^m.

    The radius is sqrt(U) with U uniform on [0, 1), which makes the samples
    uniform in area rather than clustered near the centre. Returns a float64
    tensor of shape (m,) with ``requires_grad=True``; entries beyond the
    first two are zero.
    """
    r = np.sqrt(np.random.rand())
    theta = np.random.rand() * 2 * np.pi
    re = np.zeros(m)
    # numpy trig: the bare name `math` is not bound by `from math import *`.
    re[0] = r * np.cos(theta)
    re[1] = r * np.sin(theta)
    re = torch.tensor(re, requires_grad=True)
    return re

In [None]:
model = DeepRitzNet(m)
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
in_error_iter = []  # error in Omega, recorded every print_every_iter iterations
on_error_iter = []  # error on the border of Omega, recorded every print_every_iter iterations

for i in range(iterations):
    # Monte Carlo estimate of the Ritz energy  mean of 0.5*|grad u|^2 + f*u  over Omega, f = 1.
    # Its Euler-Lagrange equation is  Delta u = f, the problem stated above
    # (a "- f*u" term would instead solve -Delta u = f).
    loss = torch.zeros(1)
    for t in range(n1):
        # in_sample returns float64 while the network weights are float32, hence .float().
        x_input = in_sample(m).float()
        y = model(x_input)
        # create_graph=True keeps grad_u differentiable w.r.t. the network parameters,
        # so the |grad u|^2 term contributes to the parameter gradient below.
        (grad_u,) = torch.autograd.grad(
            y, x_input, grad_outputs=torch.ones_like(y), create_graph=True
        )
        loss = loss + 0.5 * (grad_u[0] ** 2 + grad_u[1] ** 2) + y
    loss = loss / n1

    # Penalty that enforces u = 0 on the border of Omega.
    regularization = torch.zeros(1)
    for t in range(n2):
        y = model(on_sample(m).float())
        regularization = regularization + y ** 2
    regularization = regularization * beta / n2

    loss = loss + regularization

    # Step the optimizer.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if (i + 1) % print_every_iter == 0:
        print("Loss at the", i + 1, "th iteration:", loss.item(),
              "boundary penalty:", regularization.item())

print("Training Completed.")
# Plot once after training; plotting inside the loop would block every iteration.
draw_graph()

In [None]:
# Print the max-norm error of the current model, estimated on n3 fresh interior
# samples and n3 fresh boundary samples. `i` is the last iteration index left
# behind by the training loop.
in_error = 0
on_error = 0

with torch.no_grad():  # evaluation only; no autograd graph is needed
    for t in range(n3):
        in_x_test = in_sample(m).float()
        in_error_instant = abs((model(in_x_test) -
                                U_groundtruth(in_x_test)).item())
        in_error = max(in_error, in_error_instant)

        on_x_test = on_sample(m).float()
        on_error_instant = abs((model(on_x_test) -
                                U_groundtruth(on_x_test)).item())
        on_error = max(on_error, on_error_instant)

in_error_iter.append(in_error)
on_error_iter.append(on_error)

print("Error in Omega at the", i + 1, "th iteration:", in_error)
print("Error on the border of Omega at the", i + 1, "th iteration:", on_error)

In [None]:
for i in range(iterations):
    # Monte Carlo estimate of the Ritz energy  mean of 0.5*|grad u|^2 + f*u  over Omega, f = 1;
    # its Euler-Lagrange equation is  Delta u = f, the problem stated above.
    loss = torch.zeros(1)
    for t in range(n1):
        # in_sample returns float64 while the network weights are float32, hence .float().
        x_input = in_sample(m).float()
        y = model(x_input)
        # create_graph=True keeps grad_u differentiable w.r.t. the network parameters,
        # so the |grad u|^2 term contributes to the parameter gradient below.
        (grad_u,) = torch.autograd.grad(
            y, x_input, grad_outputs=torch.ones_like(y), create_graph=True
        )
        loss = loss + 0.5 * (grad_u[0] ** 2 + grad_u[1] ** 2) + y
    loss = loss / n1

    # Penalty that enforces u = 0 on the border of Omega.
    regularization = torch.zeros(1)
    for t in range(n2):
        y = model(on_sample(m).float())
        regularization = regularization + y ** 2
    regularization = regularization * beta / n2

    loss = loss + regularization

    # Step the optimizer.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Report the max-norm error on fresh samples every print_every_iter iterations.
    if (i + 1) % print_every_iter == 0:
        in_error = 0
        on_error = 0

        with torch.no_grad():
            for t in range(n3):
                in_x_test = in_sample(m).float()
                in_error_instant = abs((model(in_x_test) - U_groundtruth(in_x_test)).item())
                in_error = max(in_error, in_error_instant)

                on_x_test = on_sample(m).float()
                on_error_instant = abs((model(on_x_test) - U_groundtruth(on_x_test)).item())
                on_error = max(on_error, on_error_instant)

        in_error_iter.append(in_error)
        on_error_iter.append(on_error)

        print("Loss at the", i + 1, "th iteration:", loss.item())
        print("Error in Omega at the", i + 1, "th iteration:", in_error)
        print("Error on the border of Omega at the", i + 1, "th iteration:", on_error)

print("Training Completed.")