In [3]:
import torch 
import torch.nn.functional as F
import numpy as np

torch.set_default_tensor_type('torch.FloatTensor')

m=10
learning_rate=1e-4
iterations=10000
print_every_iter=100
beta=1000 #coefficient for the regularization term in the loss expression
n1=1000 #number of points in (0,1)^m
n2=100  #number of points on the border of (0,1)^m
n3=100  #number of points used for evaluating the error

class DeepRitzNet(torch.nn.Module):
    def __init__(self, m):
        super(DeepRitzNet, self).__init__()
        self.linear1=torch.nn.Linear(m,m)
        self.linear2=torch.nn.Linear(m,m)
        self.linear3=torch.nn.Linear(m,m)
        self.linear4=torch.nn.Linear(m,m)
        self.linear5=torch.nn.Linear(m,m)
        self.linear6=torch.nn.Linear(m,m)
        
        self.linear7=torch.nn.Linear(m,1)
    
    def forward(self, x):
        y=(F.relu(self.linear1(x)))**3
        y=(F.relu(self.linear2(x)))**3
        y+=x
        x=y
        y=(F.relu(self.linear3(x)))**3
        y=(F.relu(self.linear4(x)))**3
        y+=x
        x=y
        y=(F.relu(self.linear5(x)))**3
        y=(F.relu(self.linear6(x)))**3
        y+=x
        y=(F.relu(self.linear7(x)))**3
        return y

#U_groundtruth has 10 arguements
def U_groundtruth(*args):
    re = 0
    for i in range(5):
        re += args[2*i]*args[2*i+1]
    return re

#sample a (m,) numpy ndarray on the border of (0,1)^m
def border_sample():
    x=np.random.random_sample((m,))
    #randomly choose an integer between 0 and m-1(including 0 and m-1)
    dim=np.random.random_integers(0,m-1)   
    if np.random.rand()<0.5:
        x[dim]=0
    else:
        x[dim]=1
        
    return x

In [4]:
model = DeepRitzNet(m)
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

for i in range(iterations):
    #sample n1=1000 points in (0,1)^m
    #x_in_omega is a list consisting n1=1000 tensors , each tensor has the shape of (m,)
    x_in_omega=[] 
    for t in range(n1):
        x_in_omega.append(torch.tensor(np.random.random_sample((m,)),requires_grad=True))
        
    #sample n2=100 points on the border of (0,1)^m
    #x_on_omega is a list consisting n2=100 tensors , each tensor has the shape of (m,)
    x_on_omega=[]
    for t in range(n2):
        x_on_omega.append(torch.tensor(border_sample(),requires_grad=True))
    
    #calculate the loss 
    loss=torch.zeros(1)
    for t in range(n1):
        #if I miss out the ".float()" there will be an error and I don't know why
        #It seems to have something to do with the usage of relu()**3 in DeepRitzNet
        y=model(x_in_omega[t].float())  
        y.backward()
        loss+=0.5*torch.sum((x_in_omega[t].grad.float())**2)
    loss/=n1
    
    regularization=torch.zeros(1)
    for t in range(n2):
        y=model(x_on_omega[t].float())
        #there was a mistake here in the first version
        regularization+=(y-U_groundtruth(*(x_on_omega[t].float())))**2   
    regularization*=beta/n2
    
    loss+=regularization
    
    #and step the optimizer
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    #print the error
    if((i+1)%print_every_iter==0):
        error=0
        for t in range(n3):
            x_test=torch.tensor(np.random.random_sample((m,)))
            error_instant=abs((model(x_test.float())-U_groundtruth(*(x_test.float()))).item())
            if error_instant>error:
                error = error_instant
        print("Error at the",i+1,"th iteration:",error)




Error at the 100 th iteration: 2.555072784423828
Error at the 200 th iteration: 2.506585121154785
Error at the 300 th iteration: 2.4807801246643066
Error at the 400 th iteration: 2.2456953525543213
Error at the 500 th iteration: 2.538088321685791
Error at the 600 th iteration: 2.336458444595337
Error at the 700 th iteration: 2.8706722259521484
Error at the 800 th iteration: 2.4315004348754883
Error at the 900 th iteration: 2.9004335403442383
Error at the 1000 th iteration: 2.4697561264038086
Error at the 1100 th iteration: 2.5612258911132812
Error at the 1200 th iteration: 2.8590900897979736
Error at the 1300 th iteration: 2.2475733757019043
Error at the 1400 th iteration: 2.788788080215454
Error at the 1500 th iteration: 2.559713840484619
Error at the 1600 th iteration: 2.5591330528259277
Error at the 1700 th iteration: 2.691124200820923
Error at the 1800 th iteration: 2.5793068408966064
Error at the 1900 th iteration: 2.494357109069824
Error at the 2000 th iteration: 2.63294410705566

KeyboardInterrupt: 

[1]