### Import Modules

In [None]:
import torch
import argparse
import torch.nn as nn

In [None]:
import numpy as np
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import plotly.graph_objects as go
import scipy
import time

In [None]:
from script.model import Model, SimpleNet_TimeDependent
from script.approximator import PCALocalApproximation, GNNLocalApproximation, SurfaceDerivative
from script.auxiliary import *

### Measure Time

In [None]:
# Wall-clock start of the run; compared against a second time.time() reading
# in the last cell to report the total elapsed time.
before_time = time.time()

### Device Setting

In [None]:
# Ask which GPU to run on; only indices 0..3 are accepted.
device_num = int(input('Device Number : '))
assert 0 <= device_num <= 3

# Use the requested GPU when CUDA is usable, otherwise fall back to the CPU.
is_cuda = torch.cuda.is_available()
if is_cuda:
    device = torch.device('cuda:' + str(device_num))
else:
    device = torch.device('cpu')

# Explicit CPU handle, used when the trained model is saved.
device_cpu = torch.device('cpu')
print('Current cuda device is', device)

### Set K

In [None]:
# K is the group size used when tensors are reshaped to (-1, K, ...) further
# down; presumably the number of nearest neighbours per point (the
# neighbourhood tensor is called `surface.X_knn`) -- TODO confirm.
K = int(input('K : '))

### Choose the Domain and Load the Position Vectors

In [None]:
# Select the surface to solve on by leaving exactly one assignment active.
# The choice drives the allowed point counts, the dataset path, the level-set
# function eval_phi and the save path below.
domain_type = 'domain_A'
#domain_type = 'domain_B'
#domain_type = 'domain_C'

In [None]:
# Point-cloud sizes that may be chosen for each domain (validated against the
# user's `num_pts` input in the next cell). Any unknown domain name fails the
# assertion.
assert domain_type in ('domain_A', 'domain_B', 'domain_C')
num_pts_list = {
    'domain_A': [7518, 5220, 2934, 1285],
    'domain_B': [14402, 2310, 1036, 578],
    'domain_C': [21896, 16694, 10686, 7446],
}[domain_type]

In [None]:
# Choose the resolution: the number of points in the point cloud. It must be
# one of the sizes listed in num_pts_list for the selected domain.
num_pts = int(input('num_pts : '))
assert num_pts in num_pts_list

In [None]:
# Load the position vectors of the chosen point cloud.
# The file is named "<domain letter>_<num_pts>_Position.npy".
position_file = domain_type[-1] + '_' + str(num_pts) + '_Position.npy'
X_path = '../../dataset/Section_43/' + domain_type + '/Position_Vectors/' + position_file

# Convert the NumPy array into a float32 torch tensor.
position_array = np.load(X_path)
X = torch.from_numpy(position_array).float()

### Load the pretrained GNN model

In [None]:
# Instantiate the GNN on the working device and restore its pretrained weights.
GNN_model = Model().to(device)

# The checkpoint is a 3-tuple whose last entry is the model state dict; the
# first two entries are not needed here.
# map_location remaps every stored tensor onto `device`. Without it, a
# checkpoint that was saved from a CUDA device cannot be deserialized on a
# CPU-only machine (or on one that lacks the original GPU index).
_, _, state_dict = torch.load('../section_31/save/trained_GNN_model.pt', map_location=device)
GNN_model.load_state_dict(state_dict)

### Approximate Surfaces

In [None]:
class Args:
    """Empty attribute container; settings are attached to instances after creation."""

# Bundle the settings expected by GNNLocalApproximation (constructed in the
# next cell); only K is set here.
args = Args()
args.K = K

In [None]:
# Build the local surface approximation of the point cloud X with the
# pretrained GNN. The resulting object exposes the per-point quantities
# (X_knn, weight, basis, coef_a, tangent/normal vectors) read in the next cell.
surface = GNNLocalApproximation(args, X, GNN_model)

In [None]:
# Pull the precomputed per-point quantities out of the surface object.
# x: neighbourhood points (presumably the K nearest neighbours of each point),
# weight / basis: ingredients of the local weighted least-squares fit.
x, weight, basis = surface.X_knn, surface.weight, surface.basis

# Geometry of the fitted surface, consumed by SurfaceDerivative when the
# batches are built.
coef_a = surface.coef_a
tangent_vectors, normal_vectors = surface.tangent_vectors, surface.normal_vectors

### Hyperparameters Setting

In [None]:
# Sampling and optimisation settings used by the batching and training cells.
t_batch_size = 10 # temporal batch size: number of random time samples drawn per epoch
batch_size = len(X)//3 # spatial batch size: one third of the point cloud (rounded down) per batch
learning_rate = 1e-3 # initial learning rate for Adam
sch_Step_Size = 2000 # StepLR: number of epochs between learning-rate decays
sch_Gamma = 0.5 # StepLR: multiplicative decay factor
num_epochs = 20000 # total number of training epochs
T = 1. # terminal time; training times are sampled as torch.rand(...) * T

### Define u_exact and phi_exact

In [None]:
# Analytic form of the exact solution u.
def eval_u(x, y, z, t):
    """Evaluate the exact solution u(x, y, z, t) = sin(x + sin(t)) * exp(cos(y - z)).

    All arguments are tensors (or broadcast-compatible values) combined
    elementwise; the result has the broadcast shape of the inputs.
    """
    phase = x + torch.sin(t)
    envelope = torch.exp(torch.cos(y - z))
    return torch.sin(phase) * envelope

In [None]:
def init_eval(x):
    """Initial condition: sin(x0) * exp(cos(x1 - x2)) for points x of shape (N, 3).

    Columns are selected with list indices, so the result keeps shape (N, 1).
    """
    first, second, third = x[:, [0]], x[:, [1]], x[:, [2]]
    return torch.sin(first) * torch.exp(torch.cos(second - third))


def u_true_eval(t, x):
    """Exact solution sin(x0 + sin(t)) * exp(cos(x1 - x2)) at times t and points x.

    x has shape (N, 3); t must broadcast against an (N, 1) column.
    """
    first, second, third = x[:, [0]], x[:, [1]], x[:, [2]]
    return torch.sin(first + torch.sin(t)) * torch.exp(torch.cos(second - third))

In [None]:
# Level-set function phi of the selected domain. It is handed to eval_f_point
# in the loss together with eval_u.
if domain_type == 'domain_A':
    def eval_phi(x,y,z):
        """Level set x^2 + y^2 + z^2 - 1, which vanishes on the unit sphere."""
        squared_radius = x**2 + y**2 + z**2
        return squared_radius - 1
elif domain_type == 'domain_B':
    def eval_phi(x,y,z):
        """Level set r - 1 + 0.4*(x/r)*(4*z^2/r^2 - 1) with r = |(x, y, z)|."""
        r2 = x**2 + y**2 + z**2
        radius = torch.sqrt(r2)
        return radius - 1 + 0.4*x/radius * (4*z**2/r2 - 1)
else:
    assert domain_type == 'domain_C'
    def eval_phi(x,y,z):
        """Level set r - 1 - 0.4*(x/r)*(5 - 20*s/r^2 + 16*s^2/r^4), s = x^2 + y^2."""
        r2 = x**2 + y**2 + z**2
        radius = torch.sqrt(r2)
        planar2 = x**2 + y**2
        return radius - 1 - 0.4*x/radius * (5 - 20*planar2/r2 + 16*planar2**2/r2**2)

### Declare the model and the optimizer

In [None]:
# Network approximating u(t, x); it receives the initial-condition function
# (how it is used inside the network is defined in script.model).
model = SimpleNet_TimeDependent(init_eval).to(device)
# Adam optimiser with a step schedule: the learning rate is multiplied by
# sch_Gamma every sch_Step_Size calls to scheduler.step() (once per epoch).
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=sch_Step_Size, gamma=sch_Gamma)

### Define Loss function

In [None]:
def get_loss_ge(model, t, batch, device):
    """Return the mean-squared residual of the governing equation u_t - Lap(u) = f.

    The residual evaluated is ``ut - laplacian - f``, where ``ut`` is the time
    derivative of the network output obtained through ``compute_grad``,
    ``laplacian`` comes from ``derivative.laplacian`` applied to locally fitted
    coefficients of u, and ``f`` is the source term from ``eval_f_point``.

    Relies on the module-level ``K`` and ``t_batch_size`` and on the helpers
    ``expand``, ``compute_grad`` and ``eval_f_point`` (presumably provided by
    ``script.auxiliary`` -- TODO confirm).

    Args:
        model: network called as ``model(t, x)``.
        t: 1-D tensor of time samples (the training loop passes
            ``t_batch_size`` values).
        batch: tuple ``(x, weight, basis, A_inv, derivative)`` as assembled in
            the "Make Batch" cell; ``x`` is flattened to rows of 3 coordinates
            with K consecutive rows per neighbourhood.
        device: torch device on which the computation is performed.

    Returns:
        Scalar tensor holding the mean of the squared residual.
    """
    x, weight, basis, A_inv, derivative = batch
    # Move every component of the batch onto the compute device.
    x = x.to(device)
    weight = weight.to(device)
    basis = basis.to(device)
    A_inv = A_inv.to(device)
    derivative = derivative.to(device)
    
    # Give every spatial row its own time value (one column), with gradient
    # tracking enabled so that du/dt can be obtained by autograd.
    # NOTE(review): assumes expand(t, n) repeats each time sample n times so
    # that the layout matches the time-major repeat used for x -- confirm.
    t = expand(t, len(x)//t_batch_size).reshape(-1,1)
    t.requires_grad = True
    t = t.to(device)
    
    # Network values at all K points of every neighbourhood.
    u = model(t, x).reshape(-1,K)
    
    # First entry of each group of K rows (presumably the neighbourhood's
    # centre point): the PDE residual is evaluated there.
    t0 = t.reshape(-1,K)[:,[0]]
    ut = compute_grad(model(t0, x.reshape(-1,K,3)[:,0]), t0)
    f = eval_f_point(eval_u, eval_phi, x.reshape(-1,K,3)[:,0], t0.reshape(-1))
    
    # Debugging aid, intentionally left disabled.
    #print(f.mean())
    
    # b = sum_k weight_k * basis_k * u_k over the K neighbours; multiplying by
    # A_inv (a batched matrix-vector product written as multiply-and-sum)
    # yields the local fit coefficients of u.
    b = (weight.reshape(-1,K,1) * basis * u.reshape(-1,K,1)).sum(1)
    coef_u = (A_inv * b.unsqueeze(1)).sum(-1)
    laplacian = derivative.laplacian(coef_u).reshape(-1,1)

    loss_ge = ((ut - laplacian - f)**2).mean()
    return loss_ge

### Make Batch

In [None]:
# Reference solution at the terminal time T, used for the error metrics
# during training.
u_true = u_true_eval(torch.ones(len(X),1)*T, X)

# Precompute the spatial mini-batches. Every tensor is tiled t_batch_size
# times so that each spatial chunk can be paired with every time sample of
# an epoch.
batches = []
for start in range(0, len(x), batch_size):
    chunk = slice(start, start + batch_size)

    x_batch = x[chunk].repeat(t_batch_size, 1, 1).reshape(-1, 3)
    weight_batch = weight[chunk].repeat(t_batch_size, 1)
    basis_batch = basis[chunk].repeat(t_batch_size, 1, 1)

    # Weighted normal matrix A = sum_k w_k * basis_k * basis_k^T (6x6 per
    # neighbourhood); it is inverted once here and reused in every epoch.
    w_col = weight_batch.reshape(-1, K, 1, 1)
    basis_col = basis_batch.reshape(-1, K, 6, 1)
    basis_row = basis_batch.reshape(-1, K, 1, 6)
    A_batch = (w_col * basis_col * basis_row).sum(1)
    A_inv_batch = torch.linalg.inv(A_batch)

    # Surface differential operator for the tiled chunk.
    derivative_batch = SurfaceDerivative(
        coef_a[chunk].repeat(t_batch_size, 1),
        tangent_vectors[chunk].repeat(t_batch_size, 1, 1),
        normal_vectors[chunk].repeat(t_batch_size, 1),
    )

    batches.append(
        (x_batch, weight_batch, basis_batch, A_inv_batch, derivative_batch)
    )

### Train

In [None]:
# Training history: the loss is recorded every epoch, the errors every 100 epochs.
logs = {'loss_ge': [], 'l2_error': [], 'max_error': []}

for epoch in tqdm(range(1, num_epochs + 1)):
    model.train()

    # One set of random time samples in [0, T) shared by all spatial batches
    # of this epoch.
    t = torch.rand(t_batch_size)*T

    epoch_losses = []
    for batch in batches:
        optimizer.zero_grad()
        loss_ge = get_loss_ge(model, t, batch, device)
        loss_ge.backward()
        optimizer.step()
        epoch_losses.append(loss_ge.item())
    scheduler.step()

    logs['loss_ge'].append(sum(epoch_losses) / len(batches))

    # Only evaluate against the exact solution every 100 epochs.
    if epoch % 100 != 0:
        continue

    model.eval()
    x_test = X.to(device)
    t_test = (torch.ones(len(X),1)*T).to(device)
    u_pred = model(t_test, x_test).detach().cpu()

    l2_error = rel_l2_error(u_true, u_pred).item()
    max_error = rel_max_error(u_true, u_pred).item()
    logs['l2_error'].append(l2_error)
    logs['max_error'].append(max_error)

    print('epoch {} | loss_ge: {:1.2e} l2_error: {:1.2e} max_error: {:1.2e}'.format(
        epoch, logs['loss_ge'][-1], l2_error, max_error))


### Plotting

In [None]:
# One log-scale figure per group of curves: first the training loss, then the
# two relative errors together.
figure_specs = (
    [('loss_ge', r'$Loss_{GE}$')],
    [('l2_error', r'$L_2 error$'), ('max_error', r'$L_\infty error$')],
)
for curves in figure_specs:
    plt.figure()
    for log_key, curve_label in curves:
        plt.plot(logs[log_key], label=curve_label)
    plt.yscale('log')
    plt.grid()
    plt.legend()
    plt.show()

### Save

In [None]:
import os

# Results are stored as ./save/Section_43/<domain>/<domain letter>_<num_pts>.pt
save_path = './save/Section_43/' + domain_type + '/{}_{}.pt'.format(domain_type[-1], num_pts)

# Create the target directory first: torch.save does not create missing
# folders, and failing here would throw away a completed training run.
os.makedirs(os.path.dirname(save_path), exist_ok=True)

# Store the training logs together with the CPU copy of the weights so that
# the file can be loaded on machines without a GPU.
# Note: model.to(device_cpu) moves the model itself to the CPU.
torch.save((logs, model.to(device_cpu).state_dict()), save_path)

### Check the Error

In [None]:
# Read back the checkpoint written above: the training logs and the model
# state dict (bound to `load`).
logs, load = torch.load(save_path)
# Rebuild a fresh network on the working device and restore the saved weights.
model = SimpleNet_TimeDependent(init_eval).to(device)
model.load_state_dict(load)

In [None]:
# Report the most recently logged relative errors at the terminal time.
for heading, log_key in (('Rel_L2_Error : ', 'l2_error'),
                         ('Rel_Max_Error : ', 'max_error')):
    print(heading + "%.2e" % logs[log_key][-1])

### Measure Time

In [None]:
# Total wall-clock time since the timer was started, in whole seconds,
# printed as hours / minutes / seconds.
after_time = time.time()
how_long = int(after_time - before_time)
hours, remainder = divmod(how_long, 3600)
minutes, seconds = divmod(remainder, 60)
print('{}hr {}min {}sec'.format(hours, minutes, seconds))