### Import Modules

In [None]:
#torch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric as pyg
import torch_scatter

In [None]:
#visualize
import matplotlib.pyplot as plt
import plotly.graph_objs as go

In [None]:
#user_defined
from script.model import Model
from script.data_loader import GraphDataset, TorusDataset

In [None]:
#etc
from tqdm.notebook import tqdm
import warnings
import pickle
import pprint

### Device Setting

In [None]:
# Index of the CUDA device to train on (the next cell falls back to the CPU,
# ignoring this number, when CUDA is unavailable).
device_num = int(input('Cuda number : '))
# Validate against the GPUs actually present instead of a hard-coded count, and
# raise rather than assert so the check is not stripped under `python -O`.
if torch.cuda.is_available() and not 0 <= device_num < torch.cuda.device_count():
    raise ValueError(
        f'Cuda number must be in [0, {torch.cuda.device_count() - 1}], got {device_num}'
    )

### Choose Device

In [None]:
# Select the requested GPU when CUDA is present; otherwise run on the CPU.
is_cuda = torch.cuda.is_available()
if is_cuda:
    device = torch.device('cuda:' + str(device_num))
else:
    device = torch.device('cpu')
print('Current cuda device is', device)

### Hyperparameter Setting

In [None]:
# Smoothness regularizer settings: weight of the term and its type
# ('COS' = sign-invariant cosine penalty, 'MSE' = squared difference).
reg_coeff = 1e-2
reg_type = 'COS'
if reg_type not in ('COS', 'MSE'):
    raise ValueError(f"reg_type must be 'COS' or 'MSE', got {reg_type!r}")

# Neighbourhood sizes, read interactively. reg_fn keeps only the edges whose
# position modulo K_Main is below K_reg, so K_reg must not exceed K_Main.
K_Main = int(input('K (Main) : '))
K_reg = int(input('K (Regularization) : '))
if K_Main < 1:
    raise ValueError(f'K (Main) must be a positive integer, got {K_Main}')
if not 1 <= K_reg <= K_Main:
    raise ValueError(f'K (Regularization) must be in [1, {K_Main}], got {K_reg}')

# Passed through to the datasets.
Is_Normalize = False

In [None]:
# Optimisation schedule: mini-batch size and initial Adam learning rate,
# number of training epochs, and the StepLR decay (the learning rate is
# multiplied by sch_gamma every sch_step_size scheduler steps).
Batch_Size, Learning_Rate = 5, 1e-4
epochs = 10_000
sch_step_size, sch_gamma = 1_000, 0.5

In [None]:
# Snapshot of every hyperparameter. It is printed below and saved together
# with the trained weights, so the checkpoint records how it was produced.
tuning_dictionary = dict(
    Batch_Size=Batch_Size,
    Learning_Rate=Learning_Rate,
    epochs=epochs,
    sch_step_size=sch_step_size,
    sch_gamma=sch_gamma,
    reg_coeff=reg_coeff,
    reg_type=reg_type,
    K_Main=K_Main,
    K_reg=K_reg,
    Is_Normalize=Is_Normalize,
)

In [None]:
# Show the run configuration (keys sorted, default pprint layout).
pprint.PrettyPrinter().pprint(tuning_dictionary)

### Create pyg dataset & dataloader

In [None]:
# Keyword arguments shared by the training and test datasets.
dataset_kwargs = dict(K_Main=K_Main, Is_Normalize=Is_Normalize)

# Training data: TorusDataset, reshuffled every epoch and loaded by 8 workers.
train_dataset = TorusDataset(**dataset_kwargs)
train_loader = pyg.loader.DataLoader(
    train_dataset,
    batch_size=Batch_Size,
    shuffle=True,
    num_workers=8,
)

# Test data: one cheese sample (100% sampling ratio), read from position and
# normal-vector .npy files and evaluated one graph per batch.
X_cheese_path = '../../dataset/Section_44/cheese_100/Position_Vectors/cheese_100_0_Position.npy'
normal_cheese_path = '../../dataset/Section_44/cheese_100/Normal_Vectors/cheese_100_0_Normal.npy'
test_dataset = GraphDataset(X_cheese_path, normal_cheese_path, **dataset_kwargs)
test_loader = pyg.loader.DataLoader(test_dataset, batch_size=1)

### Create the Model and Define Optimizer & Scheduler

In [None]:
# Build the model on the selected device, an Adam optimiser over its
# parameters, and a StepLR schedule that decays the learning rate by
# sch_gamma every sch_step_size scheduler steps.
model = Model()
model = model.to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=Learning_Rate)
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=sch_step_size,
    gamma=sch_gamma,
)

### Define Loss

In [None]:
# Main loss: mean(1 - |cos(x_pred, x)|). The absolute value makes it
# sign-invariant, so v and -v count as a perfect match.
# Similarity is taken over the LAST axis (the 3 vector components). For (N, 3)
# inputs this is the same as the old default dim=1. For broadcast neighbour
# stacks such as (N, 1, 3) vs (N, k, 3), dim=1 would wrongly compare across
# the k neighbours instead of across components.
cosine_dist = nn.CosineSimilarity(dim=-1)


def loss_fn(x_pred, x):
    """Return the mean of ``1 - |cosine similarity|`` along the last axis.

    Args:
        x_pred: predicted vectors; must be broadcastable with ``x``.
        x: target vectors.

    Returns:
        Scalar tensor: 0 when every pair is parallel or anti-parallel,
        1 when every pair is orthogonal.
    """
    return (1 - cosine_dist(x_pred, x).abs()).mean()

In [None]:
# Auxiliary regularization term that encourages a continuous (smooth) normal-vector field.
def reg_fn(data, y_pred, k=K_reg, Reg_Type=reg_type):
    """Penalize disagreement between each prediction and its graph neighbours.

    Each node's prediction is compared with the predictions gathered through
    the first ``k`` edges of its edge group.

    NOTE(review): the reshape to ``(-1, k, 3)`` assumes ``data.edge_index``
    stores exactly ``K_Main`` consecutive edges per node, grouped in node
    order, so that group ``i`` lines up with row ``i`` of ``y_pred``. Row 0 of
    ``edge_index`` is used as the neighbour index. Confirm both against the
    dataset's graph construction.

    Args:
        data: graph batch providing ``edge_index`` of shape ``(2, E)``.
        y_pred: 2-D tensor of predictions with 3 components per node.
        k: number of leading edges per ``K_Main`` group to use; must be in
            ``[1, K_Main]``.
        Reg_Type: ``'COS'`` for the sign-invariant penalty ``mean(1 - |cos|)``
            or ``'MSE'`` for the mean squared difference.

    Returns:
        Scalar tensor holding the regularization loss.

    Raises:
        ValueError: if ``k`` is outside ``[1, K_Main]`` or ``Reg_Type`` is not
            ``'COS'``/``'MSE'``.
    """
    if not 1 <= k <= K_Main:
        raise ValueError(f'k must be in [1, K_Main={K_Main}], got {k}')

    edge_index = data.edge_index
    # Keep edges whose position within each K_Main-sized group is < k. Build the
    # mask on edge_index's device to avoid a host->device copy on every call.
    positions = torch.arange(edge_index.size(1), device=edge_index.device)
    index = edge_index[:, positions % K_Main < k]
    knn_y_pred = torch.gather(
        y_pred, dim=0, index=index[0].unsqueeze(-1).expand(-1, 3)
    ).view(-1, k, 3)
    center = y_pred.view(-1, 1, 3)

    if Reg_Type == 'MSE':
        return ((center - knn_y_pred) ** 2).mean()
    if Reg_Type == 'COS':
        # Compare along the xyz axis (dim=-1). Using dim=1 here would measure
        # similarity across the k neighbours instead of across components.
        similarity = F.cosine_similarity(center, knn_y_pred, dim=-1)
        return (1 - similarity.abs()).mean()
    raise ValueError(f"Reg_Type must be 'COS' or 'MSE', got {Reg_Type!r}")

In [None]:
# Loss history: one mean loss per epoch, appended by the training loop.
logs = {'train_loss': [], 'test_loss': []}

### Train

In [None]:
# Train for `epochs` epochs. After every epoch, evaluate on the cheese test set
# and append the mean per-batch train/test losses to `logs`.
# NOTE(review): no RNG seed is set, so runs are not reproducible. If needed,
# call torch.manual_seed(...) here and also record the seed in tuning_dictionary.
with warnings.catch_warnings():
    # Ignore warnings only while training. A bare filterwarnings('ignore')
    # would silence every warning for the rest of the notebook session.
    warnings.simplefilter('ignore')

    for epoch in tqdm(range(1, epochs + 1)):
        # ---- training pass ----
        model.train()
        running_loss = 0.
        for data in train_loader:
            data = data.to(device)
            optimizer.zero_grad()
            y_pred = model(data)
            loss = loss_fn(y_pred, data.y) + reg_coeff * reg_fn(data, y_pred)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        logs['train_loss'].append(running_loss / len(train_loader))
        # StepLR's step_size is counted in calls to step(), i.e. epochs here.
        scheduler.step()

        # ---- evaluation pass (no gradient tracking) ----
        model.eval()
        with torch.no_grad():
            running_loss = 0.
            for data in test_loader:
                data = data.to(device)
                y_pred = model(data)
                loss = loss_fn(y_pred, data.y) + reg_coeff * reg_fn(data, y_pred)
                running_loss += loss.item()
            logs['test_loss'].append(running_loss / len(test_loader))

        if epoch % 5 == 0:
            print('epoch {} | train_loss: {:1.2e}, test_loss: {:1.2e}'.format(
                epoch,
                logs['train_loss'][-1],
                logs['test_loss'][-1])
            )

### Save the trained model and the logs file.

In [None]:
import os

# Save (hyperparameters, loss history, CPU state_dict) as a single tuple.
# Create the output directory first; otherwise torch.save fails on a fresh checkout.
os.makedirs('./save', exist_ok=True)
# model.cpu() moves the model in place. Move it back afterwards so that
# re-running the training/evaluation cells on `device` keeps working.
torch.save((tuning_dictionary, logs, model.cpu().state_dict()), './save/trained_GNN_model.pt')
model.to(device)

### Plotting Loss

In [None]:
# Plot the per-epoch training loss on a log scale and save it under ./save.
fig, ax = plt.subplots()
ax.plot(logs['train_loss'], label='train')
# To overlay the test curve as well:
# ax.plot(logs['test_loss'], label='test')
ax.set_yscale('log')
ax.grid()
ax.legend()
fig.savefig('./save/plot_Loss.png')