## Test on a simple case
Consider the following Poisson equation
$$
\begin{cases}
    \Delta u = 1\qquad &(x, y)\in\Omega,\\
    u = 0\qquad &(x, y)\in\partial\Omega,
\end{cases}$$
where $\Omega = \{(x, y) \mid x^2+y^2 < 1\}$ is the open unit disk.

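The exact solution to this problem is $$u = \frac{1}{4}(x^2+y^2-1).$$ The network $v$ is trained by minimizing the Ritz energy $\int_\Omega \left(\tfrac{1}{2}|\nabla v|^2 - v\right)dx$ plus a boundary penalty; its minimizer satisfies $-\Delta v = 1$, so the network approximates $v = -u = \frac{1}{4}(1-x^2-y^2)$, and the error reported below is $v + u$.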

In [1]:
% matplotlib inline
import torch 
import torch.nn.functional as F
from torch.optim.lr_scheduler import StepLR, MultiStepLR
import numpy as np
from math import *
import matplotlib.pyplot as plt
import matplotlib.cm as cm

# Use float32 as the default floating-point dtype for new tensors.
# (torch.set_default_tensor_type is deprecated; set_default_dtype is its
# replacement for the dtype part, and the default device is already CPU.)
torch.set_default_dtype(torch.float32)

class DeepRitzNet(torch.nn.Module):
    """Deep Ritz network: three residual blocks followed by a scalar head.

    Each residual block maps ``h`` to ``h + relu(L_b(relu(L_a(h))))`` and
    keeps the width ``m``, so the input must have ``m`` features in its last
    dimension.  The head maps the ``m`` features to a single value and is
    passed through a ReLU, so the output is never negative.

    Args:
        m: width of the input and of every hidden layer.
    """

    def __init__(self, m):
        super().__init__()
        # Create linear1 ... linear6 (m -> m) and then linear7 (m -> 1) in
        # this exact order so random initialisation and state_dict keys match
        # checkpoints saved from this class.
        for idx in range(1, 7):
            setattr(self, f"linear{idx}", torch.nn.Linear(m, m))
        self.linear7 = torch.nn.Linear(m, 1)

    def forward(self, x):
        """Return the network output for ``x`` (last dimension of size ``m``)."""
        hidden = x
        residual_pairs = (
            (self.linear1, self.linear2),
            (self.linear3, self.linear4),
            (self.linear5, self.linear6),
        )
        for first, second in residual_pairs:
            hidden = hidden + F.relu(second(F.relu(first(hidden))))
        return F.relu(self.linear7(hidden))

In [2]:
def draw_graph(mod, m):
    """Show a heat map of the pointwise error of ``mod`` on [-1, 1)^2.

    The plotted value at ``(x, y)`` is ``mod(x, y) + (x^2 + y^2 - 1) / 4``
    (the second term is what ``U_groundtruth`` returns), evaluated on a grid
    with step 0.01.  Points outside the unit disk are plotted as well.

    Args:
        mod: model called once on a float32 tensor of shape ``(N, m)``; it
            must accept batched input (as ``DeepRitzNet`` does) and return
            ``N`` values.
        m: input width of ``mod``.  The coordinates fill the first two
            features; the remaining ``m - 2`` features are zero.
    """
    points = np.arange(-1, 1, 0.01)
    xs, ys = np.meshgrid(points, points)

    # Evaluate every grid point in one batched forward pass; plotting needs
    # no gradients, so skip building an autograd graph.
    batch = np.zeros((xs.size, m))
    batch[:, 0] = xs.ravel()
    batch[:, 1] = ys.ravel()
    with torch.no_grad():
        pred = mod(torch.tensor(batch, dtype=torch.float32))
    pred = pred.reshape(xs.shape).double().numpy()
    z = pred + (xs ** 2 + ys ** 2 - 1) / 4

    # Row 0 of ``z`` holds y = -1, so draw it at the bottom (origin="lower")
    # and let ``extent`` map pixels to real coordinates; the tick labels then
    # match the data on both axes.
    plt.imshow(z, cmap=cm.hot, origin="lower", extent=(-1, 1, -1, 1))
    plt.colorbar()
    ticks = np.linspace(-1, 1, 9)
    ax = plt.gca()
    ax.set_xticks(ticks)
    ax.set_yticks(ticks)
    plt.show()

In [4]:
def cal_loss(mod, dim=None):
    """Return the mean absolute error of ``mod`` inside the unit disk.

    ``mod`` is evaluated on the grid ``np.arange(-1, 1, 0.1)`` x
    ``np.arange(-1, 1, 0.1)``.  The quantity ``|mod(x, y) + u(x, y)|``, with
    ``u = (x^2 + y^2 - 1) / 4``, is averaged over the grid points with
    ``x^2 + y^2 < 1``.  The network is trained to approximate ``-u``.

    Args:
        mod: callable mapping a float32 tensor of shape ``(N, dim)`` to ``N``
            values (e.g. ``DeepRitzNet``, shape ``(N, 1)``).
        dim: input width of ``mod``.  Defaults to the notebook-level global
            ``m``, so existing calls ``cal_loss(model)`` behave as before.

    Returns:
        The mean absolute error as a float.
    """
    width = m if dim is None else dim
    points = np.arange(-1, 1, 0.1)
    xs, ys = np.meshgrid(points, points)
    inside = xs ** 2 + ys ** 2 < 1
    x_in, y_in = xs[inside], ys[inside]

    # Only interior points contribute, so evaluate just those, in one batch,
    # without building an autograd graph.
    batch = np.zeros((x_in.size, width))
    batch[:, 0] = x_in
    batch[:, 1] = y_in
    with torch.no_grad():
        pred = mod(torch.tensor(batch, dtype=torch.float32)).reshape(-1)
    err = pred.double().numpy() + (x_in ** 2 + y_in ** 2 - 1) / 4
    return float(np.mean(np.abs(err)))

In [5]:
def relative_err(mod, dim=None):
    """Return the relative squared L2 error of ``mod`` inside the unit disk.

    On the grid ``np.arange(-1, 1, 0.1)`` x ``np.arange(-1, 1, 0.1)``,
    restricted to points with ``x^2 + y^2 < 1``, this is
    ``sum((mod + u)^2) / sum(u^2)`` with ``u = (x^2 + y^2 - 1) / 4``.  The
    network is trained to approximate ``-u``.  The result is a ratio of
    *squared* norms; take its square root for the usual relative L2 error.

    Args:
        mod: callable mapping a float32 tensor of shape ``(N, dim)`` to ``N``
            values (e.g. ``DeepRitzNet``, shape ``(N, 1)``).
        dim: input width of ``mod``.  Defaults to the notebook-level global
            ``m``, so existing calls ``relative_err(model)`` behave as before.
    """
    width = m if dim is None else dim
    points = np.arange(-1, 1, 0.1)
    xs, ys = np.meshgrid(points, points)
    inside = xs ** 2 + ys ** 2 < 1
    x_in, y_in = xs[inside], ys[inside]

    # Evaluate only the interior points, in one batch, without autograd.
    batch = np.zeros((x_in.size, width))
    batch[:, 0] = x_in
    batch[:, 1] = y_in
    with torch.no_grad():
        pred = mod(torch.tensor(batch, dtype=torch.float32)).reshape(-1)
    exact = (x_in ** 2 + y_in ** 2 - 1) / 4
    err = pred.double().numpy() + exact
    return np.sum(err ** 2) / np.sum(exact ** 2)

In [6]:
def U_groundtruth(t):
    """Return the exact solution ``u = (x^2 + y^2 - 1) / 4`` at ``(t[0], t[1])``.

    ``t`` may be a torch tensor, a NumPy array or any indexable of numbers;
    only its first two entries are read.  The result is a Python float.
    """
    x, y = t[0], t[1]
    # float() works for 0-d tensors, NumPy scalars and plain numbers alike
    # (``.item()`` only exists on the first two).
    return float(x ** 2 + y ** 2 - 1) / 4

In [7]:
def validate(mod):
    """Plot the error map of ``mod`` and print its mean absolute error.

    The network input width is taken from the notebook-level global ``m``,
    the same width ``cal_loss`` uses by default.
    """
    # draw_graph requires the input width explicitly; calling it with only
    # ``mod`` raises a TypeError.
    draw_graph(mod, m)
    print(cal_loss(mod))

In [8]:
# Problem / discretisation settings.
m = 10                 # network width; inputs are (x, y) zero-padded to m features
n2 = 100               # number of points on the border of Omega
# Optimisation settings.
learning_rate = 1e-2   # Adam learning rate
iterations = 400       # maximum number of training iterations
print_every_iter = 100
beta = 500             # coefficient for the regularization term in the loss expression
gamma = 10
# NOTE(review): in the code shown here ``beta`` and ``print_every_iter`` are
# never read; the boundary-penalty weight used by the training cells is
# ``mm``, and the global ``gamma`` is only grown by 1% per iteration without
# entering the loss.

In [None]:
"""
Train on the grid,
starting from a freshly initialized model.
"""
# Train a fresh DeepRitzNet on the grid.  Loss per iteration:
#   mean over interior points of 0.5*|grad v|^2 - v   (Ritz energy, -Delta v = 1)
#   + mm * mean over boundary points of v^2            (penalty for v = 0 on the circle)
# with grad v approximated by forward differences of step h.
model = DeepRitzNet(m)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
in_error_iter = []
on_error_iter = []
mm = 1  # boundary-penalty weight; grows by 1% per iteration up to 500
points = np.arange(-1, 1, 0.1)
xs, ys = np.meshgrid(points, points)
xs = torch.tensor(xs)
ys = torch.tensor(ys)
xl, yl = xs.size()

# Collocation points never change, so build them once as batches.
# Interior: grid points strictly inside the unit disk, zero-padded to width m.
inside = xs ** 2 + ys ** 2 < 1
interior = torch.zeros(int(inside.sum()), m)
interior[:, 0] = xs[inside].float()
interior[:, 1] = ys[inside].float()
n1 = interior.shape[0]
# Boundary: n2 equally spaced points on the unit circle.
boundary = torch.zeros(n2, m)
boundary[:, 0] = torch.tensor([cos(t / n2 * (2 * pi)) for t in range(n2)])
boundary[:, 1] = torch.tensor([sin(t / n2 * (2 * pi)) for t in range(n2)])
# Finite-difference step along x (e1) and y (e2).
h = 0.0001
e1 = torch.zeros(m)
e1[0] = h
e2 = torch.zeros(m)
e2[1] = h

for k in range(iterations):
    y = model(interior)
    grad_x = (model(interior + e1) - y) / h
    grad_y = (model(interior + e2) - y) / h
    loss = (0.5 * (grad_x ** 2 + grad_y ** 2) - y).sum(dim=0) / n1

    regularization = (model(boundary) ** 2).sum(dim=0) * (mm / n2)
    # NOTE(review): gamma is grown here but never used in the loss.
    if gamma < 500:
        gamma = gamma * 1.01
    if mm < 500:
        mm = mm * 1.01

    # print loss
    print(k, " epoch, loss: ", loss.data[0].numpy())
    print(k, " epoch, regularization loss: ", regularization.data[0].numpy())
    # Evaluate the error once and reuse it for printing and early stopping.
    err = cal_loss(model)
    print(k, " loss to real solution: ", err)
    if err < 0.0001:
        break

    loss = loss + regularization

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

In [None]:
# Continue training ``model`` with the existing ``optimizer`` for up to 10
# more iterations and report the wall-clock time.
# ``import time`` (rather than ``from time import *``) keeps ``time`` bound to
# the module, so ``time.time()`` works here and in ``train``.
import time

start = time.time()

# Collocation points, rebuilt from the ``xs``/``ys`` grid of the previous
# cell: interior grid points inside the unit disk, and n2 points on the
# circle, both zero-padded to width m.
inside = xs ** 2 + ys ** 2 < 1
interior = torch.zeros(int(inside.sum()), m)
interior[:, 0] = xs[inside].float()
interior[:, 1] = ys[inside].float()
n1 = interior.shape[0]
boundary = torch.zeros(n2, m)
boundary[:, 0] = torch.tensor([cos(t / n2 * (2 * pi)) for t in range(n2)])
boundary[:, 1] = torch.tensor([sin(t / n2 * (2 * pi)) for t in range(n2)])
h = 0.0001
e1 = torch.zeros(m)
e1[0] = h
e2 = torch.zeros(m)
e2[1] = h

for k in range(10):
    y = model(interior)
    grad_x = (model(interior + e1) - y) / h
    grad_y = (model(interior + e2) - y) / h
    loss = (0.5 * (grad_x ** 2 + grad_y ** 2) - y).sum(dim=0) / n1

    regularization = (model(boundary) ** 2).sum(dim=0) * (mm / n2)
    if gamma < 500:
        gamma = gamma * 1.01
    if mm < 500:
        mm = mm * 1.01

    print(k, " epoch, loss: ", loss.data[0].numpy())
    print(k, " epoch, regularization loss: ", regularization.data[0].numpy())
    err = cal_loss(model)
    print(k, " loss to real solution: ", err)
    if err < 0.0001:
        break

    loss = loss + regularization

    optimizer.zero_grad()
    loss.backward()
    # No scheduler exists at notebook level (``scheduler`` is only a local of
    # ``train``), so only the optimizer steps here.
    optimizer.step()
stop = time.time()
print(stop - start)

In [None]:
# Plot the error map of the trained model and print its mean absolute error.
validate(mod=the_model)

In [None]:
# Relative squared L2 error of the trained model inside the unit disk.
relative_err(mod=the_model)

In [45]:
# Persist only the model weights (state_dict), not the module object.
PATH = 'test_parameters.pkl'
torch.save(obj=the_model.state_dict(), f=PATH)

In [59]:
# Rebuild a width-10 network and restore the saved weights into it.
# NOTE: torch.load unpickles the file; only load checkpoints you trust.
m = 10
PATH = 'test_parameters.pkl'
the_model = DeepRitzNet(m)
saved_state = torch.load(PATH)
the_model.load_state_dict(saved_state)

In [None]:
# Fine-tune the loaded model with a learning rate 1000x smaller.
train(the_model, initial_lr=learning_rate * 0.001)

In [62]:
def train(mod, initial_lr=learning_rate, milestones=(400,), gamma=0.1, iterations=iterations, mm=1):
    """Train ``mod`` in place with the Deep Ritz loss on the unit disk.

    Each iteration minimises
    ``mean_interior(0.5 * |grad v|^2 - v) + mm * mean_boundary(v^2)``.
    The first term is evaluated on the grid points of
    ``np.arange(-1, 1, 0.1)`` squared that lie inside the unit disk, with
    forward differences of step 1e-4. The second term uses ``n2`` points on
    the unit circle.  The minimiser solves ``-Delta v = 1`` with ``v = 0`` on
    the boundary, i.e. ``v = (1 - x^2 - y^2) / 4``, which is what
    ``cal_loss`` compares against.  ``mm`` is multiplied by 1.01 per
    iteration until it reaches 500.  Training stops early once
    ``cal_loss(mod) < 1e-4``.

    Args:
        mod: network to train; its input width must equal the global ``m``.
        initial_lr: initial Adam learning rate.
        milestones: iterations at which the learning rate is multiplied by
            ``gamma`` (``MultiStepLR``).
        gamma: learning-rate decay factor applied at each milestone.
        iterations: maximum number of training iterations.
        mm: initial weight of the boundary penalty.
    """
    import time  # the notebook-level ``time`` may be the function, not the module

    optimizer = torch.optim.Adam(mod.parameters(), lr=initial_lr)
    scheduler = MultiStepLR(optimizer, milestones=list(milestones), gamma=gamma)

    # Collocation points never change, so build them once as batches.
    points = np.arange(-1, 1, 0.1)
    xs, ys = np.meshgrid(points, points)
    inside = xs ** 2 + ys ** 2 < 1
    interior = torch.zeros(int(inside.sum()), m)
    interior[:, 0] = torch.from_numpy(xs[inside]).float()
    interior[:, 1] = torch.from_numpy(ys[inside]).float()
    n1 = interior.shape[0]
    boundary = torch.zeros(n2, m)
    boundary[:, 0] = torch.tensor([cos(t / n2 * (2 * pi)) for t in range(n2)])
    boundary[:, 1] = torch.tensor([sin(t / n2 * (2 * pi)) for t in range(n2)])
    # Finite-difference step along x (e1) and y (e2).
    h = 0.0001
    e1 = torch.zeros(m)
    e1[0] = h
    e2 = torch.zeros(m)
    e2[1] = h

    start = time.time()
    for k in range(iterations):
        y = mod(interior)
        grad_x = (mod(interior + e1) - y) / h
        grad_y = (mod(interior + e2) - y) / h
        # Ritz energy for -Delta v = 1: the source term enters with a minus
        # sign (same as the top-level training cell).
        loss = (0.5 * (grad_x ** 2 + grad_y ** 2) - y).sum(dim=0) / n1

        regularization = (mod(boundary) ** 2).sum(dim=0) * (mm / n2)
        if mm < 500:
            mm = mm * 1.01

        #print loss
        print(k, " epoch, loss: ", loss.data[0].numpy())
        print(k, " epoch, regularization loss: ", regularization.data[0].numpy())
        err = cal_loss(mod)
        print(k, " loss to real solution: ", err)
        if err < 0.0001:
            break

        loss = loss + regularization

        optimizer.zero_grad()
        loss.backward()
        # Since PyTorch 1.1 the scheduler must step after the optimizer.
        optimizer.step()
        scheduler.step()
    stop = time.time()
    print(stop - start)