In [1]:
import torch 
import torch.nn.functional as F
import numpy as np

torch.set_default_tensor_type('torch.FloatTensor')

m=10
learning_rate=1e-4  
iterations=2000  #default 10000
print_every_iter=100
beta=100 #coefficient for the regularization term in the loss expression, is set to be 1000 in section 3.1
n1=1000 #number of points in (0,1)^m
n2=100  #number of points on the border of (0,1)^m
n3=100  #number of points used for evaluating the error

class DeepRitzNet(torch.nn.Module):
    def __init__(self, m):
        super(DeepRitzNet, self).__init__()
        self.linear1=torch.nn.Linear(m,m)
        self.linear2=torch.nn.Linear(m,m)
        self.linear3=torch.nn.Linear(m,m)
        self.linear4=torch.nn.Linear(m,m)
        self.linear5=torch.nn.Linear(m,m)
        self.linear6=torch.nn.Linear(m,m)
        
        self.linear7=torch.nn.Linear(m,1)
    
    def forward(self, x):
        y=(F.relu(self.linear1(x)))**3
        y=(F.relu(self.linear2(x)))**3
        y+=x
        x=y
        y=(F.relu(self.linear3(x)))**3
        y=(F.relu(self.linear4(x)))**3
        y+=x
        x=y
        y=(F.relu(self.linear5(x)))**3
        y=(F.relu(self.linear6(x)))**3
        y+=x
        y=(F.relu(self.linear7(x)))**3
        return y

#U_groundtruth has 10 arguements
def U_groundtruth(*args):
    re = 0
    for i in range(5):
        re += args[2*i]*args[2*i+1]
    return re

#sample a (m,) numpy ndarray on the border of (0,1)^m
def border_sample():
    x=np.random.random_sample((m,))
    #randomly choose an integer between 0 and m-1(including 0 and m-1)
    dim=np.random.random_integers(0,m-1)   
    if np.random.rand()<0.5:
        x[dim]=0
    else:
        x[dim]=1
        
    return x

In [35]:
model = DeepRitzNet(m)
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
error_iter=[] #record the error every print_every_iter=100 times

for i in range(iterations):
    #sample n1=1000 points in (0,1)^m
    #x_in_omega is a list consisting n1=1000 tensors , each tensor has the shape of (m,)
    x_in_omega=[] 
    for t in range(n1):
        x_in_omega.append(torch.tensor(np.random.random_sample((m,)),requires_grad=True))
        
    #sample n2=100 points on the border of (0,1)^m
    #x_on_omega is a list consisting n2=100 tensors , each tensor has the shape of (m,)
    x_on_omega=[]
    for t in range(n2):
        x_on_omega.append(torch.tensor(border_sample(),requires_grad=True))
    
    #calculate the loss 
    loss=torch.zeros(1)
    for t in range(n1):
        #if I miss out the ".float()" there will be an error and I don't know why
        #It seems to have something to do with the usage of relu()**3 in DeepRitzNet
        y=model(x_in_omega[t].float())  
        y.backward()
        loss+=0.5*torch.sum((x_in_omega[t].grad.float())**2)
    loss/=n1
    
    regularization=torch.zeros(1)
    for t in range(n2):
        y=model(x_on_omega[t].float())
        #there was a mistake here in the first version
        regularization+=(y-U_groundtruth(*(x_on_omega[t].float())))**2   
    regularization*=beta/n2
    
    loss+=regularization
    
    #and step the optimizer
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    #print the error
    if((i+1)%print_every_iter==0):
        error=0
        for t in range(n3):
            x_test=torch.tensor(np.random.random_sample((m,)))
            error_instant=abs((model(x_test.float())-U_groundtruth(*(x_test.float()))).item())
            if error_instant>error:
                error = error_instant
        error_iter.append(error)
        print("Error at the",i+1,"th iteration:",error)




Error at the 100 th iteration: 0.6457126140594482


KeyboardInterrupt: 

In [71]:
'''
beta=1000 error fluctuates between 2 and 3
beta=200 error 2~3
beta=100 fluctuate 0.4~0.6 but if we change n3 from 100 to 1000, the error become 2~3
beta=10 fluctuate 0.5~0.7
'''

'\nbeta=1000 error fluctuates between 2 and 3\nbeta=200 error 2~3\nbeta=100 fluctuate 0.4~0.6 but if we change n3 from 100 to 1000, the error become 2~3\nbeta=10 fluctuate 0.5~0.7\n'