In [None]:
import torch
import torch.nn as nn
import argparse
import numpy as np

from script.model import Model
from script.approximator import PCALocalApproximation, GNNLocalApproximation, SurfaceDerivative
from script.auxiliary import *

from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import plotly.graph_objects as go

from scipy.spatial.distance import pdist

In [None]:
class SimpleNet_TimeIndependent(nn.Module):
    """Fully connected network (x, y, z) -> scalar with sine activations.

    Layout: 3 -> 128 -> 128 -> 128 -> 1. Each hidden layer computes
    ``sin(pi * Linear(h))``; the final layer is purely linear. Every weight
    matrix is re-initialised with Xavier-uniform (biases keep PyTorch's
    default initialisation).
    """

    def __init__(self):
        super().__init__()
        width = 128
        self.fc0 = nn.Linear(3, width)
        self.fc1 = nn.Linear(width, width)
        self.fc2 = nn.Linear(width, width)
        self.fc3 = nn.Linear(width, 1)

        # Re-initialise only after every layer exists, so the RNG draw order is
        # "construct all layers, then Xavier all weights".
        for layer in (self.fc0, self.fc1, self.fc2, self.fc3):
            nn.init.xavier_uniform_(layer.weight)

    def forward(self, x):
        """Evaluate the network on ``x`` (last dim 3); the output's last dim is 1."""
        hidden = x
        for layer in (self.fc0, self.fc1, self.fc2):
            hidden = torch.sin(torch.pi * layer(hidden))
        return self.fc3(hidden)

In [None]:
class Args:
    """Plain attribute container standing in for an ``argparse.Namespace``."""
    pass


def _read_positive_int(prompt, env_var):
    """Return a positive int taken from ``$env_var`` if set, else from ``input(prompt)``.

    The environment override lets the notebook run non-interactively
    (Restart & Run All, papermill, CI) while keeping the original prompts.

    Raises:
        ValueError: if the value is not an integer or is not positive.
    """
    import os

    raw = os.environ.get(env_var)
    if raw is None:
        raw = input(prompt)
    value = int(raw)
    if value <= 0:
        raise ValueError('{} must be a positive integer, got {}'.format(env_var, value))
    return value


args = Args()

# Number of points in the cloud; selects the dataset file and the save name.
num_pts = _read_positive_int('The number of Points : ', 'NUM_PTS')
args.num_pts = num_pts

args.surface = 'gnn'  # local surface approximator: 'gnn' or 'pca'
args.gnn_dir = '../section_31/save/trained_GNN_model.pt'  # checkpoint used when surface == 'gnn'
args.K = _read_positive_int('K : ', 'KNN_K')  # points per local neighbourhood (patch)
args.ratio = 1.0  # passed through to the approximator classes; semantics defined there
# Number of times each spatial batch is tiled in the batch-construction cell.
# NOTE(review): the model is time-independent, so the copies are identical and
# only add cost; confirm the approximator classes do not read this before lowering it.
args.t_batch_size = 10
args.save_dir = 'OpenSurface_{}'.format(args.num_pts)  # not referenced by the visible cells
args.target = 'OpenSurface_{}'.format(args.num_pts)  # not referenced by the visible cells
args.gpu = 0  # CUDA device index; -1 selects the CPU

args.lr = 1e-3
args.epochs = 20000
# StepLR settings: the training cell steps the scheduler once per epoch, so the
# learning rate is multiplied by sch_Gamma every sch_Step_Size epochs.
sch_Step_Size = 2000
sch_Gamma = 0.5

In [None]:
# Geometry label for this experiment (not referenced by the visible cells).
domain_type = "open_surface"

In [None]:
from torch import sin, pi

# def eval_z(x,y):
#     return 0.5*(x*x-y*y)+3/20*sin(args.a*pi*x)*sin(args.b*pi*y)

def eval_u(x, y, z, t):
    """Exact solution u(x, y, z) = (0.75 - x) * sin(y - z).

    ``t`` is accepted for signature compatibility but unused. The solution
    vanishes on the plane x = 0.75, where the boundary curve is sampled.
    """
    amplitude = 0.75 - x
    return amplitude * sin(y - z)

def eval_u_point(points):
    """Evaluate ``eval_u`` at t = 0 on the rows of ``points`` (columns 0, 1, 2 are x, y, z)."""
    x, y, z = (points[:, axis].clone() for axis in range(3))
    return eval_u(x, y, z, 0)

def eval_phi(x, y, z):
    """Level-set function whose zero set is the surface y^2 + z^2 = r(x)^2.

    r(x) = sqrt(1 - 0.25 x^2) * (0.75 x^2 + 0.5), i.e. a surface of
    revolution about the x-axis.
    """
    profile = (1 - 0.25 * x**2) ** 0.5 * (0.75 * x**2 + 0.5)
    return y**2 + z**2 - profile**2

# Exact (analytic) Laplace-Beltrami operator of a given function on the level set phi = 0
def laplace_beltrami_torch(eval_u_torch, eval_phi_torch, x, y, z, t):
    """Exact Laplace-Beltrami operator of ``eval_u_torch`` on the level set phi = 0.

    Uses the extension formula
        Lap_G u = Lap u - 2H * dn(u) - n^T Hess(u) n,
    where n = grad(phi) / |grad(phi)| and
        2H = (Lap phi - n^T Hess(phi) n) / |grad(phi)|.

    Args:
        eval_u_torch: callable (x, y, z, t) -> u built from torch ops.
        eval_phi_torch: callable (x, y, z) -> phi built from torch ops.
        x, y, z: 1-D coordinate tensors with ``requires_grad`` set (as built
            by ``laplace_beltrami_point``).
        t: time argument forwarded to ``eval_u_torch``.

    Returns:
        Tensor reshaped to (-1, 1), one value per point.
    """
    coords = (x, y, z)

    u = eval_u_torch(x, y, z, t)
    du, d2u = _gradient_and_hessian(u, coords)
    grad_u = torch.stack(du, 1)
    lapla_u = (d2u[0][0] + d2u[1][1] + d2u[2][2]).reshape(-1, 1)
    hess_u = torch.stack([entry for row in d2u for entry in row], 1).reshape(-1, 3, 3)

    phi = eval_phi_torch(x, y, z)
    dphi, d2phi = _gradient_and_hessian(phi, coords)
    grad_phi = torch.stack(dphi, 1)
    norm_grad_phi = torch.sqrt(dphi[0]**2 + dphi[1]**2 + dphi[2]**2).reshape(-1, 1)
    lapla_phi = (d2phi[0][0] + d2phi[1][1] + d2phi[2][2]).reshape(-1, 1)
    hess_phi = torch.stack([entry for row in d2phi for entry in row], 1).reshape(-1, 3, 3)

    normal = grad_phi / norm_grad_phi
    n_row = normal.reshape(-1, 1, 3)
    n_col = normal.reshape(-1, 3, 1)

    d_n_u = (grad_u * normal).sum(1).reshape(-1, 1)
    nth_phin = bmm3(n_row, hess_phi, n_col).reshape(-1, 1)
    twoH = (lapla_phi - nth_phin) / norm_grad_phi
    nth_un = bmm3(n_row, hess_u, n_col).reshape(-1, 1)
    return lapla_u - twoH * d_n_u - nth_un


def _gradient_and_hessian(f, coords):
    """Return (first, second) derivatives of ``f`` w.r.t. ``coords`` via ``compute_grad``.

    ``first[i]`` is df/dc_i; ``second[i][j]`` is d(df/dc_i)/dc_j, laid out row-major.
    """
    first = [compute_grad(f, c) for c in coords]
    second = [[compute_grad(g, c) for c in coords] for g in first]
    return first, second


#get the values of f, where f = u_t - LaplaceBeltrami(u)
def eval_f_torch(eval_u_torch, eval_phi_torch, x, y, z, t):
    """Source term f = u_t - Lap_G(u) for the exact solution.

    ``compute_grad`` differentiates u with respect to ``t``, so presumably ``t``
    must be a tensor requiring grad — TODO confirm. Not called by the visible cells.
    """
    u_t = compute_grad(eval_u_torch(x, y, z, t), t).reshape(-1, 1)
    lb_u = laplace_beltrami_torch(eval_u_torch, eval_phi_torch, x, y, z, t)
    return u_t - lb_u


# Point-cloud wrapper: same as laplace_beltrami_torch, but takes one `points` tensor (columns x, y, z) instead of separate coordinate tensors
def laplace_beltrami_point(eval_u, eval_phi, points, t):
    """Exact Laplace-Beltrami operator at the rows of ``points`` (columns x, y, z).

    Each coordinate column is cloned and flagged ``requires_grad`` so that
    autograd derivatives can be taken with respect to it.
    """
    coords = []
    for axis in range(3):
        column = points[:, axis].clone()
        column.requires_grad = True
        coords.append(column)
    x, y, z = coords
    return laplace_beltrami_torch(eval_u, eval_phi, x, y, z, t)

def get_loss_ge(model, batch, device):
    """Mean squared gap between the network's discrete and the exact Laplace-Beltrami operator.

    Args:
        model: network evaluated point-wise on ``x`` (one output per point).
        batch: tuple ``(x, weight, basis, A_inv, derivative)`` from the
            batch-construction cell; ``x`` holds ``args.K`` points per patch,
            and the first point of each patch is where the operator is
            compared (presumably the patch centre — TODO confirm).
        device: torch device to run on.

    Returns:
        Scalar tensor: mean over patches of (Lap_G u_model - Lap_G u_exact)^2.
    """
    x, weight, basis, A_inv, derivative = batch
    x = x.to(device)
    weight = weight.to(device)
    basis = basis.to(device)
    A_inv = A_inv.to(device)
    derivative = derivative.to(device)

    K = args.K
    u = model(x).reshape(-1, K, 1)

    # Exact target at the first point of each patch. It does not depend on the
    # model, so detach it: compute_grad has to build a graph (second derivatives
    # are taken through it), and loss.backward() would otherwise traverse that
    # whole graph every step for nothing.
    lb = laplace_beltrami_point(eval_u, eval_phi, x.reshape(-1, K, 3)[:, 0], 0).detach()

    # Weighted least-squares fit on each patch: coef = A^{-1} * sum_k w_k b_k u_k.
    b = (weight.reshape(-1, K, 1) * basis * u).sum(1)
    coef_u = (A_inv * b.unsqueeze(1)).sum(-1)
    laplacian = derivative.laplacian(coef_u).reshape(-1, 1)

    return ((laplacian - lb) ** 2).mean()


In [None]:
# Select the compute device: args.gpu == -1 means CPU, otherwise a CUDA index.
if args.gpu == -1:
    device = torch.device('cpu')
elif not torch.cuda.is_available():
    # Fall back instead of failing later at the first .to(device).
    print('CUDA is not available; falling back to CPU (args.gpu = {}).'.format(args.gpu))
    device = torch.device('cpu')
else:
    device = torch.device('cuda:{}'.format(args.gpu))

In [None]:
# Interior point cloud on the open surface.
X = torch.from_numpy(
    np.load('../../dataset/Section_42/Position_Vectors/OpenSurface_{}_Position.npy'.format(num_pts))
).float()
assert X.ndim == 2 and X.shape[1] == 3, 'expected an (N, 3) point array, got {}'.format(tuple(X.shape))

# Boundary curve: the circle where the surface y^2 + z^2 = r(x)^2 meets the
# plane x = 0.75, with r(x) = sqrt(1 - 0.25 x^2) * (0.75 x^2 + 0.5) (see eval_phi).
num_bd = 100
x_bd_plane = 0.75
# linspace over the closed interval [0, 2*pi] would repeat the angle-0 point
# (2*pi maps to the same location), so drop the endpoint: all samples are distinct.
theta = torch.linspace(0, 2 * torch.pi, num_bd + 1)[:-1]
radius = (1 - 0.25 * x_bd_plane**2) ** 0.5 * (0.75 * x_bd_plane**2 + 0.5)
X_bd = torch.stack(
    (torch.full_like(theta, x_bd_plane), radius * torch.cos(theta), radius * torch.sin(theta)),
    1,
)

u_true = eval_u_point(X)
u_true_bd = eval_u_point(X_bd).to(device)

# Cloud handed to the approximator: interior points first, then boundary samples
# (the next cell slices the first len(X) entries back out).
cloud = torch.cat((X, X_bd), 0)
if args.surface == 'pca':
    surface = PCALocalApproximation(args, cloud)
elif args.surface == 'gnn':
    gnn_model = Model().to(device)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only run.
    # NOTE: torch.load unpickles the file — only load trusted checkpoints.
    _, _, state_dict = torch.load(args.gnn_dir, map_location=device)
    gnn_model.load_state_dict(state_dict)
    surface = GNNLocalApproximation(args, cloud, gnn_model)
else:
    raise ValueError("args.surface must be 'pca' or 'gnn', got {!r}".format(args.surface))

# Keep only the interior patches; the appended boundary samples enter through
# the boundary loss in the training cell, not the governing-equation loss.
n_interior = len(X)
x = surface.X_knn[:n_interior]
weight = surface.weight[:n_interior]
basis = surface.basis[:n_interior]
coef_a = surface.coef_a[:n_interior]
tangent_vectors = surface.tangent_vectors[:n_interior]
normal_vectors = surface.normal_vectors[:n_interior]

# Roughly three mini-batches; max() avoids a zero range() step for tiny clouds.
batch_size = max(1, n_interior // 3)
# Each spatial batch is tiled t_N times. NOTE(review): the model is time-independent,
# so the copies are identical — they change the cost, not the mean loss.
t_N = args.t_batch_size
batches = []
for start in range(0, len(x), batch_size):
    stop = start + batch_size
    x_batch = x[start:stop].repeat(t_N, 1, 1).reshape(-1, 3)
    weight_batch = weight[start:stop].repeat(t_N, 1)
    basis_batch = basis[start:stop].repeat(t_N, 1, 1)

    # Per-patch weighted normal matrix A = sum_k w_k b_k b_k^T (6 basis functions).
    A_batch = (weight_batch.reshape(-1, args.K, 1, 1)
               * basis_batch.reshape(-1, args.K, 6, 1)
               * basis_batch.reshape(-1, args.K, 1, 6)).sum(1)
    A_inv_batch = torch.linalg.inv(A_batch)

    derivative_batch = SurfaceDerivative(coef_a[start:stop].repeat(t_N, 1),
                                         tangent_vectors[start:stop].repeat(t_N, 1, 1),
                                         normal_vectors[start:stop].repeat(t_N, 1))

    batches.append((x_batch,
                    weight_batch,
                    basis_batch,
                    A_inv_batch,
                    derivative_batch))

In [None]:
model = SimpleNet_TimeIndependent().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
# Stepped once per epoch below: LR is multiplied by sch_Gamma every sch_Step_Size epochs.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=sch_Step_Size, gamma=sch_Gamma)

logs = {'loss_ge': [], 'loss_bd': [], 'l2_error': [], 'max_error': []}

# Loop-invariant device copies, hoisted out of the training loop.
X_bd_device = X_bd.to(device)
x_test = X.to(device)

eval_every = 100  # epochs between error evaluations
epochs = args.epochs
for epoch in tqdm(range(1, epochs + 1)):
    model.train()

    batch_loss_ge = 0.
    batch_loss_bd = 0.
    for batch in batches:
        optimizer.zero_grad()
        loss_ge = get_loss_ge(model, batch, device)

        # Dirichlet boundary loss against the exact solution on the boundary curve.
        u_bd = model(X_bd_device).squeeze()
        loss_bd = ((u_true_bd - u_bd) ** 2).mean()

        loss = loss_ge + loss_bd
        loss.backward()
        optimizer.step()
        batch_loss_ge += loss_ge.item()
        batch_loss_bd += loss_bd.item()
    scheduler.step()

    logs['loss_ge'].append(batch_loss_ge / len(batches))
    logs['loss_bd'].append(batch_loss_bd / len(batches))

    if epoch % eval_every == 0:
        model.eval()
        # Evaluation needs no autograd graph.
        with torch.no_grad():
            u_pred = model(x_test).cpu().reshape(-1)
        print('\n', epoch)
        logs['l2_error'].append(rel_l2_error(u_true, u_pred).item())
        logs['max_error'].append(rel_max_error(u_true, u_pred).item())
        print('l2_error', logs['l2_error'][-1], 'max_error', logs['max_error'][-1])




In [None]:
import os

# Persist (training logs, state_dict). torch.save does not create parent
# directories, so make sure the target folder exists first.
# Note: model.cpu() moves the model to the CPU in place.
save_path = './save/Section_42/OpenSurface_{}.pt'.format(num_pts)
os.makedirs(os.path.dirname(save_path), exist_ok=True)
torch.save((logs, model.cpu().state_dict()), save_path)

In [None]:
# Governing-equation loss, logged once per epoch.
fig, ax = plt.subplots()
ax.plot(range(1, len(logs['loss_ge']) + 1), logs['loss_ge'], label=r'$Loss_{GE}$')
ax.set(yscale='log', xlabel='epoch', ylabel='loss', title='Governing-equation loss')
ax.grid()
ax.legend()
plt.show()

# Relative test errors, logged every 100 epochs by the training cell.
eval_epochs = [100 * (i + 1) for i in range(len(logs['l2_error']))]
fig, ax = plt.subplots()
# Keep plain words outside $...$: mathtext ignores spaces, so "$L_2 error$"
# rendered the label run together as italic "L2error".
ax.plot(eval_epochs, logs['l2_error'], label=r'$L_2$ error')
ax.plot(eval_epochs, logs['max_error'], label=r'$L_\infty$ error')
ax.set(yscale='log', xlabel='epoch', ylabel='relative error', title='Test error')
ax.grid()
ax.legend()
plt.show()