In [1]:
import torch
from torch.nn import Linear
from torch.autograd import Variable
from torch import Tensor
import random
import math
import numpy as np
from scipy.stats import ortho_group
import matplotlib.pyplot as plt
import plotly.graph_objects as go
import os
from google.colab import drive
from google.colab import files
import shutil
# Mount Google Drive at /content/drive/ (Colab only).
drive.mount('/content/drive/')
# Folder handed to Model(...) further down; Model.save writes its checkpoint files into it.
filePath = '/content/drive/MyDrive/sphereerrorana3'

Mounted at /content/drive/


In [2]:
def set_seed_everywhere(seed):
  """Seed every random number generator used in this notebook.

  Seeds Python's `random`, NumPy's global generator and PyTorch's CPU
  generator; the CUDA generators are seeded as well when CUDA is available.

  Args:
    seed: integer seed shared by all generators.
  """
  random.seed(seed)
  np.random.seed(seed)
  torch.manual_seed(seed)
  if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)


set_seed_everywhere(111111)

In [3]:
###### sample by numpy, output tensor
def Sample_sphereunif(batch,Dim):
  """Sample `batch` points uniformly on the unit sphere in R^Dim.

  Standard-normal vectors are drawn with NumPy's global generator and each
  row is divided by its Euclidean norm.

  Args:
    batch: number of points to draw.
    Dim: dimension of the ambient space.

  Returns:
    A float32 torch tensor of shape (batch, Dim) whose rows have unit norm.
  """
  gaussian = np.random.multivariate_normal(np.zeros(Dim), np.eye(Dim), batch)
  # Row norms kept as a (batch, 1) column so the division broadcasts across each row.
  row_norms = np.sqrt((gaussian * gaussian).sum(1)).reshape([batch, 1])
  on_sphere = gaussian / row_norms
  return torch.from_numpy(on_sphere).float()

def Sample_batch(batch,Dim):
  """Build a training batch of point pairs on the unit sphere.

  Each row is the concatenation of two points, giving 2*Dim columns. The
  first `int(batch)` rows pair two independently sampled points. A second
  group pairs a point with its negation, but its size is `int(0*batch)`,
  i.e. it is currently empty.

  Args:
    batch: number of independent pairs.
    Dim: dimension of each point (the original note says Dim = 3).

  Returns:
    A float32 tensor of shape (int(batch) + int(0*batch), 2*Dim).
  """
  n_independent = int(batch)
  n_antipodal = int(0*batch)

  # Keep this sampling order: every call advances NumPy's global generator.
  first_points = Sample_sphereunif(n_independent,Dim)
  second_points = Sample_sphereunif(n_independent,Dim)
  antipodal_first = Sample_sphereunif(n_antipodal,Dim)

  independent_pairs = torch.cat((first_points,second_points),1)
  antipodal_pairs = torch.cat((antipodal_first,-antipodal_first),1)
  return torch.cat((independent_pairs,antipodal_pairs),0)

def EikonalLoss(Yobs, Xp, mu, W_s, Dim, batch, device):
    """Mean-squared residual of the factored eikonal expression on the unit sphere.

    Each row of `Xp` is read as `[x_0, x]` (first `Dim` columns, next `Dim`
    columns). With `n = x / |x|`, `P = I - n n^T` and `g` the gradient of `mu`
    with respect to `x`, the per-sample prediction is

        U1 + U2 + U3 = (1 + x_0 . x) / 2 * mu^2
                     + |x - x_0|^2 * (g^T P g)
                     + 2 * mu * (x - x_0)^T P g

    The first term uses the closed form noted in the original code,
    ||grad_S |x - x_0| ||^2 = (1 + x_0 . x) / 2, which assumes unit-norm points.

    Args:
        Yobs: target values, shape (batch,).
        Xp: point pairs, shape (batch, 2*Dim); must require grad and `mu`
            must have been computed from it.
        mu: network output for `Xp`, shape (batch, 1).
        W_s: per-sample weights multiplied onto the prediction, shape (batch,).
        Dim: dimension of a single point.
        batch: number of rows in `Xp`.
        device: kept for backward compatibility; helper tensors are now created
            directly on the device of `mu` / `Xp`.

    Returns:
        Scalar tensor `mean((prediction * W_s - Yobs) ** 2)`, differentiable with
        respect to the network parameters.
    """
    # create_graph=True keeps the gradient differentiable: the loss depends on
    # d(mu)/d(x) and is itself back-propagated to the network parameters.
    D_mu = torch.autograd.grad(outputs=mu, inputs=Xp, grad_outputs=torch.ones_like(mu),
                               create_graph=True, retain_graph=True)[0]

    Xs = Xp[:, 0:Dim]
    Xr = Xp[:, Dim:2*Dim]

    # Only the gradient with respect to the second point enters the loss.
    dmu = D_mu[:, Dim:2*Dim]
    batch_dmu = dmu.reshape(batch, Dim, 1)

    # Unit normal at the second point.
    Xr_norm = torch.sqrt((Xr * Xr).sum(1)).reshape([batch, 1])
    normal_x = (Xr / Xr_norm).reshape(batch, Dim, 1)

    # Tangential projector, one Dim x Dim matrix per sample. The identity is
    # created on Xp's device/dtype; a CPU-only identity made bmm fail on GPU.
    identity = torch.eye(Dim, device=Xp.device, dtype=Xp.dtype)
    Proj = identity.repeat(batch, 1, 1) - torch.bmm(normal_x, torch.transpose(normal_x, 1, 2))
    batch_dmu_m = torch.bmm(Proj, batch_dmu)

    Xrs = Xr - Xs
    Xrs2 = Xrs * Xrs

    # U1: mu^2 times the closed-form squared surface gradient of |x - x_0|.
    U1 = (1 + (Xs * Xr).sum(1)) / 2 * mu[:, 0] ** 2

    # U2: |x - x_0|^2 * g^T P g.
    U2 = Xrs2.sum(1) * torch.bmm(dmu.reshape(batch, 1, Dim), batch_dmu_m).reshape(batch)

    # U3: cross term 2 * mu * (x - x_0)^T P g.
    U3 = 2 * mu[:, 0] * torch.bmm(Xrs.reshape(batch, 1, Dim), batch_dmu_m).reshape(batch)

    Ypred = (U1 + U2 + U3) * W_s
    return torch.mean((Ypred - Yobs) ** 2)

In [4]:
def init_weights(m):
    """Re-initialise a `torch.nn.Linear` layer in place; ignore any other module.

    Weights and bias are drawn uniformly from [-stdv, stdv] with
    stdv = 2 / sqrt(in_features). The bound is printed for every layer that is
    initialised. Intended for use with `module.apply(init_weights)`.

    Args:
        m: a module; only exact `torch.nn.Linear` instances are modified.
    """
    if type(m) != torch.nn.Linear:
        return
    in_features = m.weight.size(1)
    stdv = 2. / math.sqrt(in_features)
    print('stdv is:\n',stdv)
    # Weight first, then bias, so the random stream is consumed in a fixed order.
    for parameter in (m.weight, m.bias):
        parameter.data.uniform_(-stdv,stdv)

class NN(torch.nn.Module):
    """Small residual MLP mapping a pair of 3-D points to a non-negative scalar.

    Layout: Linear(6, 64) -> `nl` residual blocks of width 64 -> Linear(64, 1),
    followed by an absolute value. Each residual block computes
    act(fc3(act(fc1(x))) + fc2(x)), so the skip path is itself a linear layer.

    Args:
        nl: number of residual blocks.
        activation: activation module applied after every hidden layer.
    """

    def __init__(self, nl=1, activation=torch.nn.ELU()):
        super().__init__()
        self.act = activation
        width = 64

        # Input layer: two concatenated 3-D points.
        self.fc0 = Linear(2*3, width)

        # Residual blocks: fc1 -> fc3 is the main path, fc2 the linear skip path.
        self.rn_fc1 = torch.nn.ModuleList([Linear(width, width) for _ in range(nl)])
        self.rn_fc2 = torch.nn.ModuleList([Linear(width, width) for _ in range(nl)])
        self.rn_fc3 = torch.nn.ModuleList([Linear(width, width) for _ in range(nl)])

        # Output layer: a single scalar per input row.
        self.fc4 = Linear(width, 1)

    def forward(self, x):
        """Return |fc4(h)| where h is the output of the last residual block."""
        hidden = self.act(self.fc0(x))
        for main_in, skip, main_out in zip(self.rn_fc1, self.rn_fc2, self.rn_fc3):
            block_input = hidden
            inner = self.act(main_in(block_input))
            hidden = self.act(main_out(inner) + skip(block_input))
        # Absolute value keeps the predicted factor non-negative.
        return torch.abs(self.fc4(hidden))

class Model():
    """Training / checkpointing wrapper around `NN` for the sphere distance experiment.

    All hyper-parameters live in the nested `self.Params` dictionary created by
    `__init__`. `train` creates the following attributes:
        network, optimizer, (scheduler), list_train_loss,
        list_distance_error, list_rel_distance_error,
        list_distance_error_counterpoint, list_rel_distance_error_counterpoint.
    """
    # Kept from the original notebook: printed once, when the class body is executed.
    print('Enter Model!')

    def __init__(self, ModelPath, device='cpu'):
        """Store the configuration; no network is built until `train` or `load`.

        Args:
            ModelPath: directory that `save` writes checkpoint files into.
            device: torch device string used for the network and all tensors.
        """
        self.Params = {}
        self.Params['ModelPath'] = ModelPath
        self.Params['Device'] = device
        self.Params['Pytorch Amp (bool)'] = False

        self.Params['Network'] = {}
        self.Params['Network']['Dimension'] = 3
        self.Params['Network']['Number of Residual Blocks'] = 1
        self.Params['Network']['Layer activation'] = torch.nn.ELU()

        self.Params['Training'] = {}
        self.Params['Training']['Batch Size'] = 256
        self.Params['Training']['Number of Iterations'] = 200000
        self.Params['Training']['Learning Rate'] = 0.000005
        self.Params['Training']['Use Scheduler (bool)'] = False

    def _init_network(self):
        """Build a freshly initialised float32 `NN` on the configured device."""
        self.network = NN(nl=self.Params['Network']['Number of Residual Blocks'],
                          activation=self.Params['Network']['Layer activation'])
        self.network.apply(init_weights)
        self.network.float()
        self.network.to(torch.device(self.Params['Device']))

    def save(self, iteration='', distance_error=''):
        """Write network, optimizer state and all recorded histories to a checkpoint.

        The file is named
        `<ModelPath>/Model_Iteration_<iteration>_DistanceError_<distance_error>.pt`.
        The target directory is created when missing, so a long training run is
        not lost at the final save because of a missing folder.
        """
        model_dir = self.Params['ModelPath']
        if model_dir:
            os.makedirs(model_dir, exist_ok=True)
        torch.save({'iteration': iteration,
                    'model_state_dict': self.network.state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict(),
                    'train_loss': self.list_train_loss,
                    'distance_error': self.list_distance_error,
                    'rel_distance_error': self.list_rel_distance_error,
                    'distance_error_counterpoint': self.list_distance_error_counterpoint,
                    'rel_distance_error_counterpoint': self.list_rel_distance_error_counterpoint},
                   '{}/Model_Iteration_{}_DistanceError_{}.pt'.format(model_dir, str(iteration).zfill(5), distance_error))

    def load(self, load_filepath):
        """Rebuild the network and restore weights and histories from a checkpoint.

        The optimizer state stored in the checkpoint is not restored.
        """
        self._init_network()
        checkpoint = torch.load(load_filepath, map_location=torch.device(self.Params['Device']))
        self.network.load_state_dict(checkpoint['model_state_dict'])
        self.network.to(torch.device(self.Params['Device']))

        self.list_train_loss = checkpoint['train_loss']
        self.list_distance_error = checkpoint['distance_error']
        self.list_rel_distance_error = checkpoint['rel_distance_error']
        self.list_distance_error_counterpoint = checkpoint['distance_error_counterpoint']
        self.list_rel_distance_error_counterpoint = checkpoint['rel_distance_error_counterpoint']

    def _batch_loss(self, Xp, device):
        """Forward `Xp` through the network and return the eikonal loss.

        Targets and weights are all-ones vectors created on `device`.
        """
        dim = self.Params['Network']['Dimension']
        batch = Xp.shape[0]
        output = self.network(Xp)
        W_r = torch.ones(batch, device=device)
        W_s = torch.ones(batch, device=device)
        return EikonalLoss(W_r, Xp, output, W_s, dim, batch, device)

    def _evaluate_pair(self, xs, xr, label=''):
        """Compare the predicted distance with the exact arc angle for one pair.

        The prediction is `network([xs, xr]) * |xr - xs|`. When the right-hand
        side of the equation is not 1 this would additionally need a factor
        sqrt(W(x_0)) (note carried over from the original code).

        Args:
            xs, xr: 1-D tensors holding the two points.
            label: prefix for the printed names (e.g. 'counterpoint_').

        Returns:
            (distance_error, rel_distance_error) as Python floats, where
            distance_error = exact arc - predicted distance.
        """
        device = torch.device(self.Params['Device'])
        xs = xs.to(device)
        xr = xr.to(device)
        # Pure evaluation: no autograd graph is needed here.
        with torch.no_grad():
            cos_angle = (xs @ xr) / (torch.sqrt((xs * xs).sum(0)) * torch.sqrt((xr * xr).sum(0)))
            # Clamp guards against rounding pushing the cosine just outside
            # [-1, 1], which would make arccos return NaN.
            real_arc = torch.arccos(torch.clamp(cos_angle, -1.0, 1.0)).item()
            offset = xr - xs
            chord = torch.sqrt((offset * offset).sum(0))
            distance = (self.network(torch.cat((xs, xr), 0)) * chord).item()

        distance_error = real_arc - distance
        rel_distance_error = abs(distance_error) / real_arc

        print('{}real_arc is:'.format(label), real_arc)
        print('{}distance is:'.format(label), distance)
        print('{}distance_error is:'.format(label), distance_error)
        print('{}rel_distance_error is:'.format(label), rel_distance_error)
        return distance_error, rel_distance_error

    def train(self):
        """Run the full training loop and save one checkpoint at the end.

        A new batch is sampled every 1000 iterations and reused in between.
        Every 1000 iterations the prediction is also compared with the exact
        arc for a fixed antipodal pair and for one random pair.

        Returns:
            (list_train_loss, list_distance_error, list_rel_distance_error,
             list_distance_error_counterpoint, list_rel_distance_error_counterpoint)
        """
        print('Enter training!')

        ###### Initialising the network
        self._init_network()

        device = torch.device(self.Params['Device'])
        dim = self.Params['Network']['Dimension']
        use_amp = self.Params['Pytorch Amp (bool)']
        use_scheduler = self.Params['Training']['Use Scheduler (bool)']
        n_iterations = self.Params['Training']['Number of Iterations']

        self.optimizer = torch.optim.Adam(self.network.parameters(), lr=self.Params['Training']['Learning Rate'])
        if use_scheduler:
            self.scheduler = torch.optim.lr_scheduler.MultiStepLR(self.optimizer, milestones=[50000, 150000], gamma=0.1)

        self.list_train_loss = []
        self.list_distance_error = []
        self.list_rel_distance_error = []

        self.list_distance_error_counterpoint = []
        self.list_rel_distance_error_counterpoint = []

        # Fixed antipodal evaluation pair: +e_1 and -e_1.
        pole = torch.zeros(dim)
        pole[0] = 1.0
        antipode = torch.zeros(dim)
        antipode[0] = -1.0

        for iteration in range(n_iterations):

            #### sample a batch per 1000 steps
            if iteration % 1000 == 0:
                Xp = Sample_batch(self.Params['Training']['Batch Size'], dim).to(device)

            Xp.requires_grad_()
            self.optimizer.zero_grad()

            if use_amp:
                # Was `with auotocast():`, an undefined name that raised NameError.
                with torch.autocast(device_type=device.type):
                    loss = self._batch_loss(Xp, device)
            else:
                loss = self._batch_loss(Xp, device)

            loss.backward()
            self.optimizer.step()

            train_loss = loss.item()
            self.list_train_loss.append(train_loss)

            if use_scheduler:
                self.scheduler.step()

            if iteration % 1000 == 0:
                ######### comparison for the prescribed (antipodal) start/end points
                counterpoint_distance_error, counterpoint_rel_distance_error = self._evaluate_pair(
                    pole, antipode, label='counterpoint_')
                self.list_distance_error_counterpoint.append(counterpoint_distance_error)
                self.list_rel_distance_error_counterpoint.append(counterpoint_rel_distance_error)

                ######### comparison for two randomly drawn points
                myXs = Sample_sphereunif(1, dim).reshape(dim)
                myXr = Sample_sphereunif(1, dim).reshape(dim)
                distance_error, rel_distance_error = self._evaluate_pair(myXs, myXr)
                self.list_distance_error.append(distance_error)
                self.list_rel_distance_error.append(rel_distance_error)

                print('learning rate is:', self.optimizer.state_dict()['param_groups'][0]['lr'])

                print('batch size is:', self.Params['Training']['Batch Size'])

                print("iteration = {} -- Training loss = {:.4e} ".format(iteration+1, train_loss))

            if iteration == n_iterations - 1:
                # distance_error is the value from the most recent evaluation step.
                self.save(iteration=iteration, distance_error=distance_error)

        return self.list_train_loss, self.list_distance_error, self.list_rel_distance_error, self.list_distance_error_counterpoint, self.list_rel_distance_error_counterpoint

Enter Model!


In [7]:
def real_dis(xs,xr):
  """Angle in radians between two 1-D tensors, returned as a Python float.

  Computed as arccos(xs . xr / (|xs| |xr|)); for unit vectors this is the arc
  length between them on the unit sphere.
  """
  norm_product = torch.sqrt((xs * xs).sum(0)) * torch.sqrt((xr * xr).sum(0))
  cosine = (xs @ xr) / norm_product
  return torch.arccos(cosine).item()

# Networks already restored by `check`, keyed by (path, modification time).
_CHECK_NETWORK_CACHE = {}


def _load_check_network(load_filepath):
  """Return the `NN` stored in the checkpoint, reading each file version only once."""
  try:
    cache_key = (os.fspath(load_filepath), os.path.getmtime(load_filepath))
  except TypeError:
    # Not a path (e.g. an open file object): load without caching.
    cache_key = None

  if cache_key is not None and cache_key in _CHECK_NETWORK_CACHE:
    return _CHECK_NETWORK_CACHE[cache_key]

  # map_location lets a checkpoint saved on a GPU be restored on a CPU-only host.
  checkpoint = torch.load(load_filepath, map_location=torch.device('cpu'))
  network = NN(nl=1,activation=torch.nn.ELU())
  network.load_state_dict(checkpoint['model_state_dict'])
  if cache_key is not None:
    _CHECK_NETWORK_CACHE[cache_key] = network
  return network


def check(load_filepath,xs,xr,batch):
  """Predict the distance between paired points with a saved network.

  The prediction for each pair is `NN([xs, xr]) * |xr - xs|`.

  `error_distance` calls this function once per orthogonal transform for
  every angle, so the checkpoint is cached after the first read instead of
  being reloaded from disk each time.

  Args:
    load_filepath: checkpoint written by `Model.save`.
    xs, xr: 2-D tensors of shape (batch, Dim); they are concatenated along
      dim 1 and reduced along dim 1, so 1-D inputs are not accepted.
    batch: number of rows in `xs` / `xr`.

  Returns:
    Tensor of shape (batch,) with the predicted distances.
  """
  checkNN = _load_check_network(load_filepath)
  xp = torch.cat((xs,xr),1)
  check_output = checkNN(xp).reshape(batch)
  xsr = xr - xs
  norm_xsr = torch.sqrt((xsr * xsr).sum(1))
  tensor_check_distance = norm_xsr * check_output
  return tensor_check_distance

In [8]:
def Samplecircle(batch,Dim,theta,R):
  """Sample points on the circle at polar angle `theta` around the pole (0, 0, 1).

  The sampled points have last coordinate R*cos(theta); their first Dim-1
  coordinates are R*sin(theta) times a random unit vector from
  `Sample_sphereunif(batch, Dim-1)`.

  Args:
    batch: number of points on the circle.
    Dim: ambient dimension; the pole is hard-coded with three components.
    theta: polar angle (a NumPy scalar in the visible caller).
    R: sphere radius.

  Returns:
    (real_arc, fix_batch, s_batch): the arc length R*theta, `batch` copies of
    the pole as a (batch, 3) array, and the (batch, Dim) array of circle points.
  """
  arc_length = R * theta
  circle_radius = R * np.sin(theta)
  height = (R * np.cos(theta)).reshape([1,1])

  pole_rows = np.repeat(np.array([[0.0,0.0,1.0]]),batch,0)

  # Sample_sphereunif returns a torch tensor; np.hstack converts it to an array.
  in_plane = circle_radius * Sample_sphereunif(batch,Dim-1)
  height_column = np.repeat(height,batch,0)
  circle_rows = np.hstack((in_plane,height_column))

  return arc_length, pole_rows, circle_rows

def orthosample(fix_batch0,s_batch0,batch,Dim):
  """Apply one random orthogonal matrix to both point sets.

  A single matrix is drawn with `scipy.stats.ortho_group.rvs(Dim)` and applied
  to every row of both inputs, so inner products between corresponding rows
  are preserved.

  Args:
    fix_batch0: array reshapeable to (batch, Dim).
    s_batch0: array reshapeable to (batch, Dim).
    batch: number of rows.
    Dim: dimension of each point.

  Returns:
    (s_batch, fix_batch): the transformed arrays, each of shape (batch, Dim).
    Note that the order is swapped relative to the arguments.
  """
  # One orthogonal matrix, replicated for every row of the batch.
  orth_stack = ortho_group.rvs(Dim)[np.newaxis].repeat(batch,0)

  def _transform(points):
    # result[i, :, 0] = orth @ points[i]
    as_rows = points.reshape(batch,1,Dim)
    return np.einsum('ijk,ink->ijn',orth_stack,as_rows).reshape(batch,Dim)

  fix_batch = _transform(fix_batch0)
  s_batch = _transform(s_batch0)
  return s_batch, fix_batch
  
def error_distance(load_filepath,N,M,batch,Dim):
  """Mean relative distance error of a saved network as a function of arc angle.

  For each angle theta_i = i*pi/N (i = 1..N) one circle of `batch` points
  around the pole is sampled. The pole/circle configuration is then moved by
  `M` random orthogonal matrices, and for each of them the mean relative
  error between the distance predicted by `check` and the exact arc theta_i
  is recorded. The `M` means are printed and averaged again.

  Args:
    load_filepath: checkpoint passed on to `check`.
    N: number of angular steps.
    M: number of random orthogonal transforms per angle.
    batch: number of points per circle.
    Dim: ambient dimension.

  Returns:
    (theta, list_rel_error): `theta` has N+1 entries including theta[0] = 0,
    while `list_rel_error` has N entries matching theta[1:].
  """
  theta = np.linspace(0,np.pi,N+1)
  list_rel_error = []
  for angle in theta[1:]:
    real_arc_i, pole_rows, circle_rows = Samplecircle(batch,Dim,angle,1)

    ###### expectation over random orthogonal transforms at this angle
    rel_errors_at_angle = []
    for _ in range(M):
      moved_circle, moved_pole = orthosample(pole_rows,circle_rows,batch,Dim)
      ###### convert arrays to float32 tensors
      pole_tensor = torch.from_numpy(moved_pole).float()
      circle_tensor = torch.from_numpy(moved_circle).float()
      predicted = check(load_filepath,pole_tensor,circle_tensor,batch)
      rel_error = abs(predicted - real_arc_i) / real_arc_i
      rel_errors_at_angle.append(torch.mean(rel_error).item())

    print('list_rel_error_i is:',rel_errors_at_angle)
    list_rel_error.append(np.array(rel_errors_at_angle).mean())
  return theta, list_rel_error

随机种子数：111111

球面上均匀采样，batch size:64-1024

固定学习率：$10^{-5}$



In [None]:
# Train a fresh model; checkpoints go to `filePath`. Hyper-parameters are the ones
# hard-coded in Model.__init__ at the time the cell is run.
model_sphereerrorana3 = Model(filePath)
# train() returns five histories: training loss, signed / relative distance error for a
# random pair, and signed / relative distance error for the fixed antipodal pair.
list_train_loss_sphereerrorana3, list_distance_error_sphereerrorana3, list_rel_distance_error_sphereerrorana3, list_distance_error_counterpoint_sphereerrorana3, list_rel_distance_error_counterpoint_sphereerrorana3 = model_sphereerrorana3.train()

Enter training!
stdv is:
 0.8164965809277261
stdv is:
 0.25
stdv is:
 0.25
stdv is:
 0.25
stdv is:
 0.25
counterpoint_real_arc is: 3.1415927410125732
counterpoint_distance is: 0.10789135098457336
counterpoint_distance_error is: 3.033701390028
counterpoint_rel_distance_error is: 0.9656571173035629
real_arc is: 0.9089365005493164
distance is: 0.10896614193916321
distance_error is: 0.7999703586101532
rel_distance_error is: 0.880116881791732
learning rate is: 1e-05
batch size is: 64
iteration = 1 -- Training loss = 1.5981e+00 
counterpoint_real_arc is: 3.1415927410125732
counterpoint_distance is: 0.8643876910209656
counterpoint_distance_error is: 2.2772050499916077
counterpoint_rel_distance_error is: 0.7248568601090022
real_arc is: 2.4680941104888916
distance is: 0.4921039342880249
distance_error is: 1.9759901762008667
rel_distance_error is: 0.8006137885112709
learning rate is: 1e-05
batch size is: 64
iteration = 1001 -- Training loss = 3.1542e-01 
counterpoint_real_arc is: 3.1415927410125

KeyboardInterrupt: ignored

随机种子数：111111

球面上均匀采样，batch size:64-1024

固定学习率：$5\times10^{-6}$

In [None]:
# Re-run of the training cell; it reuses the same variable names, so the previous
# model object and histories are overwritten.
model_sphereerrorana3 = Model(filePath)
# train() returns five histories: training loss, signed / relative distance error for a
# random pair, and signed / relative distance error for the fixed antipodal pair.
list_train_loss_sphereerrorana3, list_distance_error_sphereerrorana3, list_rel_distance_error_sphereerrorana3, list_distance_error_counterpoint_sphereerrorana3, list_rel_distance_error_counterpoint_sphereerrorana3 = model_sphereerrorana3.train()

Enter training!
stdv is:
 0.8164965809277261
stdv is:
 0.25
stdv is:
 0.25
stdv is:
 0.25
stdv is:
 0.25
counterpoint_real_arc is: 3.1415927410125732
counterpoint_distance is: 0.9526342749595642
counterpoint_distance_error is: 2.188958466053009
counterpoint_rel_distance_error is: 0.6967671008010673
real_arc is: 2.171769618988037
distance is: 0.3773156702518463
distance_error is: 1.7944539487361908
rel_distance_error is: 0.8262634917843352
learning rate is: 5e-06
batch size is: 64
iteration = 1 -- Training loss = 1.1279e+01 
counterpoint_real_arc is: 3.1415927410125732
counterpoint_distance is: 0.3252789378166199
counterpoint_distance_error is: 2.8163138031959534
counterpoint_rel_distance_error is: 0.8964605012068565
real_arc is: 1.469959020614624
distance is: 0.20008686184883118
distance_error is: 1.2698721587657928
rel_distance_error is: 0.8638826939779789
learning rate is: 5e-06
batch size is: 64
iteration = 1001 -- Training loss = 7.3682e-01 
counterpoint_real_arc is: 3.141592741012

In [None]:
# Checkpoint written by Model.save at the end of training (file name contains the
# last iteration index and the last recorded random-pair distance error).
load_filepath = '/content/drive/MyDrive/sphereerrorana3/Model_Iteration_199999_DistanceError_0.0666954517364502.pt'
# N: number of angles in (0, pi]; M: random orthogonal transforms per angle;
# batch: points sampled on each circle; Dim: ambient dimension.
N = 150
M = 60
batch = 100
Dim = 3
theta, list_rel_error = error_distance(load_filepath,N,M,batch,Dim)

list_rel_error_i is: [0.0058084409683942795, 0.007094109430909157, 0.0058263689279556274, 0.002776046749204397, 0.02750350534915924, 0.0029526809230446815, 0.008134596049785614, 0.007174130063503981, 0.004611421842128038, 0.002492983592674136, 0.0094780707731843, 0.012037665583193302, 0.004965643398463726, 0.0034271853510290384, 0.023673854768276215, 0.010883600451052189, 0.00250048004090786, 0.002448043320327997, 0.009861936792731285, 0.008434864692389965, 0.005757542792707682, 0.01565389335155487, 0.0019106672843918204, 0.01551175955682993, 0.002441379241645336, 0.013580615632236004, 0.0010683794971555471, 0.02925945073366165, 0.011442464776337147, 0.004952193703502417, 0.0005386546836234629, 0.011230566538870335, 0.018194174394011497, 0.007370664272457361, 0.007321793586015701, 0.0009288273868151009, 0.01903849095106125, 0.004364971071481705, 0.002611086005344987, 0.0005411056918092072, 0.016691939905285835, 0.001987261464819312, 0.009258931502699852, 0.004448130261152983, 0.0008382

In [None]:
###### N=150, M=60
# Mean relative distance error versus arc angle. theta[0] = 0 is skipped because
# error_distance only evaluates theta[1:].
fig = go.Figure()
fig.add_trace(go.Scatter(x=theta[1:],y=list_rel_error,name="rel_Distance_Error"))
fig.update_layout(width=700,height=480,template="plotly_white",margin=dict(l=5,r=5,t=5,b=5))
fig.show()

随机种子数：111111

球面上均匀采样，batch size:256

固定学习率：$5\times10^{-6}$

In [None]:
# Re-run of the training cell; it reuses the same variable names, so the previous
# model object and histories are overwritten.
model_sphereerrorana3 = Model(filePath)
# train() returns five histories: training loss, signed / relative distance error for a
# random pair, and signed / relative distance error for the fixed antipodal pair.
list_train_loss_sphereerrorana3, list_distance_error_sphereerrorana3, list_rel_distance_error_sphereerrorana3, list_distance_error_counterpoint_sphereerrorana3, list_rel_distance_error_counterpoint_sphereerrorana3 = model_sphereerrorana3.train()

Enter training!
stdv is:
 0.8164965809277261
stdv is:
 0.25
stdv is:
 0.25
stdv is:
 0.25
stdv is:
 0.25
counterpoint_real_arc is: 3.1415927410125732
counterpoint_distance is: 0.5794610977172852
counterpoint_distance_error is: 2.562131643295288
counterpoint_rel_distance_error is: 0.815551809070415
real_arc is: 2.220097541809082
distance is: 2.584193468093872
distance_error is: -0.36409592628479004
rel_distance_error is: 0.1639999682122528
learning rate is: 5e-06
batch size is: 256
iteration = 1 -- Training loss = 5.3532e-01 
counterpoint_real_arc is: 3.1415927410125732
counterpoint_distance is: 2.157649040222168
counterpoint_distance_error is: 0.9839437007904053
counterpoint_rel_distance_error is: 0.3131989986942955
real_arc is: 0.9480069875717163
distance is: 0.8305995464324951
distance_error is: 0.11740744113922119
rel_distance_error is: 0.12384659889475695
learning rate is: 5e-06
batch size is: 256
iteration = 1001 -- Training loss = 3.1590e-01 
counterpoint_real_arc is: 3.141592741

In [None]:
# Checkpoint written by Model.save at the end of training (file name contains the
# last iteration index and the last recorded random-pair distance error).
load_filepath = '/content/drive/MyDrive/sphereerrorana3/Model_Iteration_199999_DistanceError_0.04263043403625488.pt'
# N: number of angles in (0, pi]; M: random orthogonal transforms per angle;
# batch: points sampled on each circle; Dim: ambient dimension.
N = 150
M = 60
batch = 100
Dim = 3
theta, list_rel_error = error_distance(load_filepath,N,M,batch,Dim)

list_rel_error_i is: [0.005168708506971598, 0.00521864602342248, 0.04537035897374153, 0.014955861493945122, 0.007649421226233244, 0.01797030121088028, 0.008209235034883022, 0.007823568768799305, 0.015216431580483913, 0.025075137615203857, 0.005500806495547295, 0.008772588334977627, 0.017337104305624962, 0.0007924166857264936, 0.02445007860660553, 0.001978943357244134, 0.025139419361948967, 0.004567363765090704, 0.0028418577276170254, 0.026621118187904358, 0.0011828250717371702, 0.00031278710230253637, 0.03279978409409523, 0.0012828233884647489, 0.03228210657835007, 0.006180645897984505, 0.013717269524931908, 0.013292807154357433, 0.0030163892079144716, 0.003605679841712117, 0.018693720921874046, 0.001478323363699019, 0.001541272271424532, 0.03320923075079918, 0.014822341501712799, 0.01610627956688404, 0.009069555439054966, 0.016342582181096077, 0.003519812598824501, 0.02535787969827652, 0.02673274464905262, 0.01400891412049532, 0.0008684407803229988, 0.004022706300020218, 0.00732371211

In [None]:
###### N=150, M=60
# Mean relative distance error versus arc angle. theta[0] = 0 is skipped because
# error_distance only evaluates theta[1:].
fig = go.Figure()
fig.add_trace(go.Scatter(x=theta[1:],y=list_rel_error,name="rel_Distance_Error"))
fig.update_layout(width=700,height=480,template="plotly_white",margin=dict(l=5,r=5,t=5,b=5))
fig.show()

随机种子数：111111

球面上均匀采样，batch size:256

固定学习率：$5\times10^{-6}$

In [5]:
# Re-run of the training cell; it reuses the same variable names, so the previous
# model object and histories are overwritten.
model_sphereerrorana3 = Model(filePath)
# train() returns five histories: training loss, signed / relative distance error for a
# random pair, and signed / relative distance error for the fixed antipodal pair.
list_train_loss_sphereerrorana3, list_distance_error_sphereerrorana3, list_rel_distance_error_sphereerrorana3, list_distance_error_counterpoint_sphereerrorana3, list_rel_distance_error_counterpoint_sphereerrorana3 = model_sphereerrorana3.train()

Enter training!
stdv is:
 0.8164965809277261
stdv is:
 0.25
stdv is:
 0.25
stdv is:
 0.25
stdv is:
 0.25
counterpoint_real_arc is: 3.1415927410125732
counterpoint_distance is: 1.8485509157180786
counterpoint_distance_error is: 1.2930418252944946
counterpoint_rel_distance_error is: 0.4115879847868924
real_arc is: 1.4402529001235962
distance is: 1.4284539222717285
distance_error is: 0.011798977851867676
rel_distance_error is: 0.008192295846691327
learning rate is: 5e-06
batch size is: 256
iteration = 1 -- Training loss = 6.1921e-01 
counterpoint_real_arc is: 3.1415927410125732
counterpoint_distance is: 2.388197660446167
counterpoint_distance_error is: 0.7533950805664062
counterpoint_rel_distance_error is: 0.23981309567311324
real_arc is: 1.2399766445159912
distance is: 1.2882639169692993
distance_error is: -0.048287272453308105
rel_distance_error is: 0.038942082229424906
learning rate is: 5e-06
batch size is: 256
iteration = 1001 -- Training loss = 2.9417e-01 
counterpoint_real_arc is: 3

In [9]:
# Checkpoint written by Model.save at the end of training (file name contains the
# last iteration index and the last recorded random-pair distance error).
load_filepath = '/content/drive/MyDrive/sphereerrorana3/Model_Iteration_199999_DistanceError_-0.018203258514404297.pt'
# N: number of angles in (0, pi]; M: random orthogonal transforms per angle;
# batch: points sampled on each circle; Dim: ambient dimension.
N = 150
M = 60
batch = 100
Dim = 3
theta, list_rel_error = error_distance(load_filepath,N,M,batch,Dim)

list_rel_error_i is: [0.012080712243914604, 0.010605258867144585, 0.04197119176387787, 0.007569608744233847, 0.0015190981794148684, 0.01403043046593666, 0.0008160591241903603, 0.00947424117475748, 0.010174624621868134, 0.0013449211837723851, 0.017085213214159012, 0.013559653423726559, 0.02847704105079174, 0.03513307869434357, 0.003047756850719452, 0.01179946307092905, 0.012065482325851917, 0.020257795229554176, 0.02291610650718212, 0.012160268612205982, 0.012152697890996933, 0.0015320009551942348, 0.010023322887718678, 0.004190839361399412, 0.013940073549747467, 0.023838598281145096, 0.020339997485280037, 0.019405972212553024, 0.04125402122735977, 0.004628792870789766, 0.005440558772534132, 0.006505826488137245, 0.012067101895809174, 0.024794163182377815, 0.01061555277556181, 0.021993698552250862, 0.011766479350626469, 0.014116176404058933, 0.015581305138766766, 0.01878369227051735, 0.008519452065229416, 0.023952094838023186, 0.02908434346318245, 0.005128721706569195, 0.001740341191180

In [10]:
###### N=150, M=60
# Mean relative distance error versus arc angle. theta[0] = 0 is skipped because
# error_distance only evaluates theta[1:].
fig = go.Figure()
fig.add_trace(go.Scatter(x=theta[1:],y=list_rel_error,name="rel_Distance_Error"))
fig.update_layout(width=700,height=480,template="plotly_white",margin=dict(l=5,r=5,t=5,b=5))
fig.show()