### Import Modules

In [None]:
import torch
import argparse
import torch.nn as nn

In [None]:
import numpy as np
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import plotly.graph_objects as go
import scipy
import time
import os
from pprint import pprint

In [None]:
from script.model import Model, SimpleNet_TimeDependent
from script.approximator import PCALocalApproximation, GNNLocalApproximation, SurfaceDerivative
from script.auxiliary import *

### Measure Time

In [None]:
# Wall-clock start time of the notebook run; the elapsed time is printed by
# the final "Measure Time" cell.
before_time = time.time()

### Device Settings

In [None]:
# Select the compute device. The device number only matters when CUDA is
# available; on a CPU-only machine it is read but ignored.
device_num = int(input('Device Number : '))
is_cuda = torch.cuda.is_available()
if is_cuda:
    n_gpus = torch.cuda.device_count()
    # Validate against the GPUs that actually exist (not a hard-coded 4), and
    # raise explicitly: ``assert`` is stripped when Python runs with -O.
    if not 0 <= device_num < n_gpus:
        raise ValueError('Device number must be in [0, {}], got {}'.format(n_gpus - 1, device_num))
device = torch.device('cuda:' + str(device_num) if is_cuda else 'cpu')
device_cpu = torch.device('cpu')
print('Current cuda device is', device)

### Set K

In [None]:
# Number of neighbours per local surface patch (passed to the approximator as args.K).
K = int(input('K : '))
# The local least-squares fit inverts a 6x6 matrix that is a sum of K
# rank-one terms, so with fewer than 6 neighbours it is always singular.
# Fail here instead of after the (expensive) surface approximation.
if K < 6:
    raise ValueError('K must be at least 6 (got {}): the 6-term local fit needs '
                     'at least 6 neighbours to be solvable'.format(K))

### Choose the Domain and Load the Point Cloud

In [None]:
# Dataset files are named '<domain>_...'; collect the domain prefix of every
# file so the user's domain choice can be validated.
npy_list = os.listdir('../../dataset/Section_46')
domain_type_list = [file_name.split('_')[0] for file_name in npy_list]

In [None]:
domain_type = input("Choose the Domain Type (ex. 'centaur0') : ")
# Explicit error instead of ``assert`` (which is removed under ``python -O``),
# and tell the user which domains are actually available.
if domain_type not in domain_type_list:
    raise ValueError('Unknown domain type {!r}; available: {}'.format(
        domain_type, sorted(set(domain_type_list))))

In [None]:
# Terminal time: training times are sampled uniformly from [0, T).
T = float(input('Terminal Time : '))
# ``not T > 0`` also rejects NaN, which ``T <= 0`` would let through.
if not T > 0:
    raise ValueError('Terminal time must be positive, got {}'.format(T))

In [None]:
# Load the point cloud (position vectors) of the chosen domain as float32.
X_path = '../../dataset/Section_46/' + domain_type +  '_Point_Clouds_array.npy'
X = torch.from_numpy(np.load(X_path)).to(torch.float32)
# Optional normalisation, currently disabled:
#   X = X - X.mean(dim=0)              # centre the cloud at the origin
#   X = X / torch.max(torch.abs(X))    # scale all coordinates into [-1, 1]

### Load the pretrained GNN model

In [None]:
GNN_model = Model().to(device)
# The checkpoint is a 3-tuple whose last entry is the state dict.
# map_location remaps the stored tensors onto the selected device; without it
# torch.load tries to restore them on the device they were saved from, which
# fails on CPU-only machines or when that GPU index is not present.
_, _, state_dict = torch.load('../section_31/save/trained_GNN_model.pt', map_location=device)
GNN_model.load_state_dict(state_dict)

### Approximate Surfaces

In [None]:
# Options object handed to the surface approximator (presumably it reads
# ``args.K``). ``argparse.Namespace`` is the standard attribute container for
# this; the ``Args`` alias keeps any existing ``Args()`` call working.
Args = argparse.Namespace

args = Args(K=K)

In [None]:
# Local surface approximation of the point cloud X using the pretrained GNN
# (neighbourhood size presumably taken from args.K). The following cell reads
# its X_knn, weight, basis, coef_a, tangent_vectors and normal_vectors fields.
surface = GNNLocalApproximation(args, X, GNN_model)

In [None]:
# Unpack the per-point local-approximation data computed by the approximator:
# neighbour coordinates, fit weights, basis values, surface coefficients, and
# tangent / normal frames.
(x, weight, basis,
 coef_a, tangent_vectors, normal_vectors) = (
    surface.X_knn, surface.weight, surface.basis,
    surface.coef_a, surface.tangent_vectors, surface.normal_vectors)

### Hyperparameter Settings

In [None]:
t_batch_size = 10  # number of time samples drawn per epoch (temporal batch size)
# Spatial batch size: about a third of the points per optimizer step.
# max(1, ...) keeps the batching step non-zero for clouds with fewer than
# 3 points (range(..., 0) would raise).
batch_size = max(1, len(X) // 3)
learning_rate = 1e-3
sch_Step_Size = 2000  # StepLR: scheduler steps between learning-rate decays
sch_Gamma = 0.5       # StepLR: multiplicative learning-rate decay factor
num_epochs = 20000

### Choose a Random Source Point

In [None]:
# Pick one point of the cloud uniformly at random. Indexing with an index
# tensor returns a copy, so ``source`` does not share storage with X.
source_index = torch.randint(len(X), (1,))
source = X[source_index][0]


def init_eval(x):
    """Gaussian bump exp(-75 * |x - source|^2) centred at ``source``.

    Presumably used by the network as the initial value at t = 0. Only the
    first three coordinates of ``x`` are used, and the trailing axis is kept
    with size 1. ``source`` is looked up at call time.
    """
    dx = x[..., [0]] - source[0]
    dy = x[..., [1]] - source[1]
    dz = x[..., [2]] - source[2]
    return torch.exp(-75 * (dx**2 + dy**2 + dz**2))

### Declare the model and the optimizer

In [None]:
# Time-dependent network, built around the initial-value function init_eval.
model = SimpleNet_TimeDependent(init_eval)
model = model.to(device)

optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)
# Multiply the learning rate by sch_Gamma every sch_Step_Size scheduler steps.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimizer,
                                            step_size=sch_Step_Size,
                                            gamma=sch_Gamma)

### Define the Loss Function

In [None]:
def get_loss_ge(model, t, batch, device, diffusivity=0.05):
    """Return the mean squared residual of ``u_t = diffusivity * laplacian(u)`` on one batch.

    Args:
        model: network called as ``model(t, x)``; its output is reshaped to
            ``(-1, K)``, so it must produce one value per row of ``x``.
        t: 1-D tensor of sampled times (may live on CPU; it is moved to
            ``device`` here).
        batch: tuple ``(x, weight, basis, A_inv, derivative)``. ``x`` holds
            3-D points in consecutive groups of ``K`` neighbours; ``A_inv`` is
            the batched inverse of the weighted normal matrix of ``basis``;
            ``derivative`` provides ``.laplacian(coef)``.
        device: torch device on which the loss is evaluated.
        diffusivity: coefficient multiplying the Laplacian (default 0.05).

    Returns:
        Scalar tensor: mean of ``(u_t - diffusivity * laplacian) ** 2``.

    Relies on the module-level globals ``K`` and ``t_batch_size`` and on the
    helpers ``expand`` / ``compute_grad`` (presumably from
    ``script.auxiliary``).
    """
    x, weight, basis, A_inv, derivative = (item.to(device) for item in batch)

    # Presumably repeats each sampled time len(x) // t_batch_size times so
    # that every row of x gets a time value (TODO confirm ``expand`` semantics).
    t = expand(t, len(x) // t_batch_size).reshape(-1, 1).to(device)
    # Enable gradients after the device move so t is a leaf on ``device``.
    t.requires_grad_(True)

    # Network values at every neighbour, one row of K values per point.
    u = model(t, x).reshape(-1, K)

    # Time derivative at the first neighbour of each group of K rows
    # (presumably the centre point itself).
    t0 = t.reshape(-1, K)[:, [0]]
    x0 = x.reshape(-1, K, 3)[:, 0]
    ut = compute_grad(model(t0, x0), t0)

    # Weighted least-squares coefficients of u in the local basis:
    # coef_u = A_inv @ sum_k(w_k * u_k * basis_k), written as a batched mat-vec.
    b = (weight.reshape(-1, K, 1) * basis * u.reshape(-1, K, 1)).sum(1)
    coef_u = (A_inv * b.unsqueeze(1)).sum(-1)
    laplacian = derivative.laplacian(coef_u).reshape(-1, 1)

    return ((ut - diffusivity * laplacian) ** 2).mean()

### Build the Training Batches

In [None]:
# Split the points into spatial chunks of ``batch_size`` and tile every chunk
# ``t_batch_size`` times along the first axis. The loss then pairs each copy
# with one sampled time. The local normal matrix
# A = sum_k w_k * basis_k basis_k^T does not depend on time. It is therefore
# built and inverted once per spatial chunk, and only the inverse is tiled,
# instead of inverting t_batch_size identical copies.
batches = []
for i in range(0, len(x), batch_size):
    chunk = slice(i, i + batch_size)

    x_batch = x[chunk].repeat(t_batch_size, 1, 1).reshape(-1, 3)
    weight_batch = weight[chunk].repeat(t_batch_size, 1)
    basis_batch = basis[chunk].repeat(t_batch_size, 1, 1)

    # Per-point weighted normal matrix of the local basis, shape (n, p, p).
    basis_chunk = basis[chunk]
    A_chunk = (weight[chunk].reshape(-1, K, 1, 1)
               * basis_chunk.unsqueeze(-1)
               * basis_chunk.unsqueeze(-2)).sum(1)
    A_inv_batch = torch.linalg.inv(A_chunk).repeat(t_batch_size, 1, 1)

    derivative_batch = SurfaceDerivative(coef_a[chunk].repeat(t_batch_size, 1),
                                         tangent_vectors[chunk].repeat(t_batch_size, 1, 1),
                                         normal_vectors[chunk].repeat(t_batch_size, 1))

    batches.append((x_batch,
                    weight_batch,
                    basis_batch,
                    A_inv_batch,
                    derivative_batch))

### Train

In [None]:
# Training history: per-epoch mean loss plus the source point that defines
# the initial condition.
logs = {'loss_ge': [], 'source': source}

for epoch in tqdm(range(1, num_epochs + 1)):
    model.train()

    # One fresh set of times in [0, T) per epoch, shared by all spatial batches.
    t = torch.rand(t_batch_size) * T

    epoch_loss_sum = 0.
    for batch in batches:
        optimizer.zero_grad()
        loss_ge = get_loss_ge(model, t, batch, device)
        loss_ge.backward()
        optimizer.step()
        epoch_loss_sum += loss_ge.item()
    # The learning-rate schedule advances once per epoch.
    scheduler.step()

    logs['loss_ge'].append(epoch_loss_sum / len(batches))

    if epoch % 100 == 0:
        print('epoch {} | loss_ge: {:1.2e}'.format(epoch, logs['loss_ge'][-1]))

### Plotting

In [None]:
# Per-epoch training loss on a logarithmic axis.
fig, ax = plt.subplots()
ax.plot(logs['loss_ge'], label=r'$Loss_{GE}$')
ax.set_yscale('log')
ax.grid()
ax.legend()
plt.show()

### Save the model

In [None]:
save_dir = './save/Section_46/' + domain_type + '/'
# Create the output folder if needed, so a long training run is not lost to
# a FileNotFoundError at the final save.
os.makedirs(save_dir, exist_ok=True)
# Note: model.cpu() moves the model itself to the CPU before its state dict is taken.
torch.save((logs, model.cpu().state_dict()), save_dir + domain_type + '.pt')

### Measure Time

In [None]:
# Report the total wall-clock run time as hours / minutes / seconds.
after_time = time.time()
how_long = int(after_time - before_time)
hours, remainder = divmod(how_long, 3600)
minutes, seconds = divmod(remainder, 60)
print('{}hr {}min {}sec'.format(hours, minutes, seconds))