### Import Modules

In [None]:
#torch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric as pyg
import torch_scatter

In [None]:
#visualize
import matplotlib.pyplot as plt
import plotly.graph_objs as go

In [None]:
#user_defined
from script.model import Model
from script.data_loader import GraphDataset, TorusDataset

In [None]:
#etc
from tqdm.notebook import tqdm
import warnings
import pickle
import pprint

### Device Setting

In [None]:
#ask interactively which CUDA device index to use
device_num = int(input('Cuda number : '))
# Only the indices 0, 1, 2 and 3 are accepted.
assert 0 <= device_num < 4

### Choose Device

In [None]:
# Use the requested GPU when CUDA is available, otherwise fall back to the CPU.
is_cuda = torch.cuda.is_available()
if is_cuda:
    device = torch.device('cuda:' + str(device_num))
else:
    device = torch.device('cpu')
print('Current cuda device is', device)

### Hyperparameter Setting

In [None]:
# Neighbourhood sizes are entered interactively:
#   K_Main is forwarded to the dataset constructors,
#   K_reg  is the default neighbour count used by reg_fn.
K_Main = int(input('K (Main) : '))
K_reg = int(input('K (Regularization) : '))

# Regularization term: its kind ('COS' or 'MSE', see reg_fn) and its weight
# in the total loss.
reg_type = 'COS'
reg_coeff = 1e-2

# Forwarded unchanged to the dataset constructors.
Is_Normalize = False

In [None]:
# Optimisation settings.
epochs = 10_000          # number of passes over the training loader
Batch_Size = 5           # batch size of the training loader
Learning_Rate = 1e-4     # initial Adam learning rate

# StepLR schedule: multiply the learning rate by sch_gamma
# every sch_step_size scheduler steps.
sch_step_size = 1_000
sch_gamma = 0.5

In [None]:
# Snapshot of every hyperparameter of this run; it is printed below and
# stored next to the trained weights when the model is saved.
tuning_dictionary = dict(
    Batch_Size=Batch_Size,
    Learning_Rate=Learning_Rate,
    epochs=epochs,
    sch_step_size=sch_step_size,
    sch_gamma=sch_gamma,
    reg_coeff=reg_coeff,
    reg_type=reg_type,
    K_Main=K_Main,
    K_reg=K_reg,
    Is_Normalize=Is_Normalize,
)

In [None]:
# Show the collected hyperparameters for a quick visual check before training.
pprint.pprint(tuning_dictionary)

### Create pyg dataset & dataloader

In [None]:
# Options shared by the training and the test dataset.
dataset_options = dict(K_Main=K_Main, Is_Normalize=Is_Normalize)

#train
train_dataset = TorusDataset(**dataset_options)
train_loader = pyg.loader.DataLoader(
    train_dataset,
    batch_size=Batch_Size,
    shuffle=True,
    num_workers=8,
)

#test_Cheese (100% sampling_ratio)
X_cheese_path = '../../dataset/Section_44/cheese_100/Position_Vectors/cheese_100_0_Position.npy'
normal_cheese_path = '../../dataset/Section_44/cheese_100/Normal_Vectors/cheese_100_0_Normal.npy'
test_dataset = GraphDataset(
    X_cheese_path,
    normal_cheese_path,
    **dataset_options,
)
# The test loader yields one graph per batch, in a fixed order.
test_loader = pyg.loader.DataLoader(test_dataset, batch_size=1)

### Create the Model and Define Optimizer & Scheduler

In [None]:
# Build the network and move it to the selected device.
model = Model()
model = model.to(device)

# Adam over all model parameters, starting at Learning_Rate.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=Learning_Rate,
)

# Decay the learning rate by sch_gamma every sch_step_size scheduler steps.
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=sch_step_size,
    gamma=sch_gamma,
)

### Define Loss

In [None]:
#use cosine_similarity for main loss
# The similarity is taken over the LAST axis (the vector components).
# For the (N, 3) tensors of the main loss this equals the default dim=1;
# for the (N, 1, 3) vs (N, k, 3) tensors passed by reg_fn the default dim=1
# would reduce over the neighbour axis instead of the vector components.
cosine_dist = nn.CosineSimilarity(dim=-1)


def loss_fn(x_pred, x):
    """Return the mean of ``1 - |cos(x_pred, x)|`` over all compared vectors.

    The absolute value makes the loss insensitive to a sign flip of either
    vector: parallel and anti-parallel vectors both give 0, orthogonal
    vectors give 1.

    Args:
        x_pred: tensor whose last axis holds the vector components.
        x: tensor broadcastable against ``x_pred``, same last-axis size.

    Returns:
        A scalar tensor.
    """
    return (1 - cosine_dist(x_pred, x).abs()).mean()

In [None]:
#auxiliary loss used as a regularizing term so that the predicted normal vector field is continuous.
def reg_fn(data, y_pred, k=K_reg, Reg_Type=reg_type):
    """Penalize differences between each prediction and k gathered predictions.

    Args:
        data: batch object exposing ``edge_index`` (2 x num_edges).
        y_pred: predictions, reshaped here to rows of 3 components.
        k: number of edges kept out of every group of ``K_Main`` consecutive
            edges. The default is bound to ``K_reg`` when this function is
            defined, not when it is called.
        Reg_Type: 'MSE' for a mean squared difference, 'COS' to reuse
            ``loss_fn``. The default is bound to ``reg_type`` at definition.

    Returns:
        A scalar tensor.

    Raises:
        AssertionError: if ``Reg_Type`` is neither 'MSE' nor 'COS'.
    """
    # Keep the first k columns of every run of K_Main consecutive edges.
    # assumes edge_index stores exactly K_Main edges per node, grouped per
    # node -- TODO confirm against the dataset construction.
    edge_positions = torch.arange(data.edge_index.size(1))
    keep_mask = edge_positions % K_Main < k
    kept_edges = data.edge_index[:, keep_mask]

    # Row 0 of the kept edges selects which prediction rows to gather.
    row_index = kept_edges[0].unsqueeze(-1).expand(-1, 3)
    gathered_pred = torch.gather(y_pred, dim=0, index=row_index).view(-1, k, 3)

    # One prediction per node, broadcast against its k gathered predictions.
    own_pred = y_pred.view(-1, 1, 3)

    if Reg_Type == 'MSE':
        return ((own_pred - gathered_pred) ** 2).mean()

    assert Reg_Type == 'COS'
    # NOTE(review): loss_fn receives 3-D tensors here; confirm that the cosine
    # similarity inside loss_fn reduces over the last (component) axis.
    return loss_fn(own_pred, gathered_pred)

In [None]:
#per-epoch loss history: one entry is appended to each list every epoch
logs = {
    'train_loss': [],
    'test_loss': [],
}

### Train

In [None]:
#torch.manual_seed(tuning_dictionary['seed_number_train'])
# Suppress every Python warning from here on.
warnings.filterwarnings(action='ignore')

for epoch in tqdm(range(1, epochs + 1)):
    # ---- training pass: one optimizer step per batch ----
    model.train()
    train_total = 0.
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        y_pred = model(data)
        main_term = loss_fn(y_pred, data.y)
        reg_term = reg_fn(data, y_pred)
        loss = main_term + reg_coeff * reg_term
        loss.backward()
        optimizer.step()
        train_total += loss.item()
    # Store the mean loss per batch for this epoch.
    logs['train_loss'].append(train_total / len(train_loader))
    # The learning-rate schedule advances once per epoch.
    scheduler.step()

    # ---- evaluation pass: same loss, no gradients, no parameter update ----
    model.eval()
    test_total = 0.
    with torch.no_grad():
        for data in test_loader:
            data = data.to(device)
            y_pred = model(data)
            main_term = loss_fn(y_pred, data.y)
            reg_term = reg_fn(data, y_pred)
            loss = main_term + reg_coeff * reg_term
            test_total += loss.item()
    logs['test_loss'].append(test_total / len(test_loader))

    # Report the latest losses every 5th epoch.
    if epoch % 5 == 0:
        report = 'epoch {} | train_loss: {:1.2e}, test_loss: {:1.2e}'.format(
            epoch,
            logs['train_loss'][-1],
            logs['test_loss'][-1],
        )
        print(report)

epoch 3755 | train_loss: 7.50e-03, test_loss: 2.07e-03
epoch 3760 | train_loss: 7.71e-03, test_loss: 2.37e-03
epoch 3765 | train_loss: 8.14e-03, test_loss: 2.37e-03
epoch 3770 | train_loss: 7.91e-03, test_loss: 2.42e-03
epoch 3775 | train_loss: 9.02e-03, test_loss: 2.19e-03
epoch 3780 | train_loss: 7.59e-03, test_loss: 2.21e-03
epoch 3785 | train_loss: 7.54e-03, test_loss: 1.87e-03
epoch 3790 | train_loss: 1.31e-02, test_loss: 7.13e-03
epoch 3795 | train_loss: 8.37e-03, test_loss: 3.53e-03
epoch 3800 | train_loss: 7.85e-03, test_loss: 2.41e-03
epoch 3805 | train_loss: 7.93e-03, test_loss: 2.48e-03
epoch 3810 | train_loss: 7.17e-03, test_loss: 2.22e-03
epoch 3815 | train_loss: 7.85e-03, test_loss: 2.08e-03
epoch 3820 | train_loss: 7.75e-03, test_loss: 2.30e-03
epoch 3825 | train_loss: 8.57e-03, test_loss: 2.35e-03
epoch 3830 | train_loss: 8.29e-03, test_loss: 2.22e-03
epoch 3835 | train_loss: 7.08e-03, test_loss: 2.02e-03
epoch 3840 | train_loss: 7.35e-03, test_loss: 2.10e-03
epoch 3845

epoch 4500 | train_loss: 6.64e-03, test_loss: 1.84e-03
epoch 4505 | train_loss: 6.30e-03, test_loss: 1.86e-03
epoch 4510 | train_loss: 6.46e-03, test_loss: 1.79e-03
epoch 4515 | train_loss: 6.46e-03, test_loss: 1.76e-03
epoch 4520 | train_loss: 6.45e-03, test_loss: 1.90e-03
epoch 4525 | train_loss: 6.48e-03, test_loss: 1.94e-03
epoch 4530 | train_loss: 6.26e-03, test_loss: 1.83e-03
epoch 4535 | train_loss: 7.19e-03, test_loss: 1.82e-03
epoch 4540 | train_loss: 6.69e-03, test_loss: 1.96e-03
epoch 4545 | train_loss: 6.65e-03, test_loss: 1.97e-03
epoch 4550 | train_loss: 7.19e-03, test_loss: 2.16e-03
epoch 4555 | train_loss: 6.67e-03, test_loss: 1.89e-03
epoch 4560 | train_loss: 6.23e-03, test_loss: 1.76e-03
epoch 4565 | train_loss: 6.17e-03, test_loss: 1.75e-03
epoch 4570 | train_loss: 6.08e-03, test_loss: 1.94e-03
epoch 4575 | train_loss: 6.50e-03, test_loss: 1.94e-03
epoch 4580 | train_loss: 6.50e-03, test_loss: 1.75e-03
epoch 4585 | train_loss: 6.29e-03, test_loss: 1.74e-03
epoch 4590

epoch 5245 | train_loss: 6.54e-03, test_loss: 1.86e-03
epoch 5250 | train_loss: 6.41e-03, test_loss: 1.66e-03
epoch 5255 | train_loss: 5.77e-03, test_loss: 1.60e-03
epoch 5260 | train_loss: 6.05e-03, test_loss: 1.64e-03
epoch 5265 | train_loss: 5.98e-03, test_loss: 1.73e-03
epoch 5270 | train_loss: 5.90e-03, test_loss: 1.68e-03
epoch 5275 | train_loss: 6.00e-03, test_loss: 1.75e-03
epoch 5280 | train_loss: 6.36e-03, test_loss: 1.69e-03
epoch 5285 | train_loss: 6.21e-03, test_loss: 1.71e-03
epoch 5290 | train_loss: 6.26e-03, test_loss: 1.63e-03
epoch 5295 | train_loss: 6.23e-03, test_loss: 1.64e-03
epoch 5300 | train_loss: 6.44e-03, test_loss: 1.80e-03
epoch 5305 | train_loss: 6.24e-03, test_loss: 1.63e-03
epoch 5310 | train_loss: 6.19e-03, test_loss: 1.69e-03
epoch 5315 | train_loss: 6.40e-03, test_loss: 1.74e-03
epoch 5320 | train_loss: 6.04e-03, test_loss: 1.70e-03
epoch 5325 | train_loss: 6.08e-03, test_loss: 1.70e-03
epoch 5330 | train_loss: 5.85e-03, test_loss: 1.63e-03
epoch 5335

epoch 5990 | train_loss: 5.77e-03, test_loss: 1.56e-03
epoch 5995 | train_loss: 5.83e-03, test_loss: 1.62e-03
epoch 6000 | train_loss: 6.15e-03, test_loss: 1.57e-03
epoch 6005 | train_loss: 5.63e-03, test_loss: 1.56e-03
epoch 6010 | train_loss: 5.58e-03, test_loss: 1.54e-03
epoch 6015 | train_loss: 6.03e-03, test_loss: 1.55e-03
epoch 6020 | train_loss: 5.95e-03, test_loss: 1.61e-03
epoch 6025 | train_loss: 6.00e-03, test_loss: 1.56e-03
epoch 6030 | train_loss: 5.80e-03, test_loss: 1.55e-03
epoch 6035 | train_loss: 5.89e-03, test_loss: 1.70e-03
epoch 6040 | train_loss: 5.86e-03, test_loss: 1.64e-03
epoch 6045 | train_loss: 5.74e-03, test_loss: 1.56e-03
epoch 6050 | train_loss: 5.64e-03, test_loss: 1.67e-03
epoch 6055 | train_loss: 5.68e-03, test_loss: 1.60e-03
epoch 6060 | train_loss: 5.90e-03, test_loss: 1.54e-03
epoch 6065 | train_loss: 5.59e-03, test_loss: 1.56e-03
epoch 6070 | train_loss: 5.84e-03, test_loss: 1.59e-03
epoch 6075 | train_loss: 5.88e-03, test_loss: 1.54e-03
epoch 6080

epoch 6735 | train_loss: 5.60e-03, test_loss: 1.95e-03
epoch 6740 | train_loss: 5.58e-03, test_loss: 1.69e-03
epoch 6745 | train_loss: 5.70e-03, test_loss: 1.61e-03
epoch 6750 | train_loss: 5.86e-03, test_loss: 1.64e-03
epoch 6755 | train_loss: 5.76e-03, test_loss: 1.73e-03
epoch 6760 | train_loss: 5.62e-03, test_loss: 1.66e-03
epoch 6765 | train_loss: 5.90e-03, test_loss: 1.66e-03
epoch 6770 | train_loss: 5.59e-03, test_loss: 1.56e-03
epoch 6775 | train_loss: 5.70e-03, test_loss: 1.53e-03
epoch 6780 | train_loss: 5.69e-03, test_loss: 1.58e-03
epoch 6785 | train_loss: 5.53e-03, test_loss: 1.55e-03
epoch 6790 | train_loss: 5.78e-03, test_loss: 1.56e-03
epoch 6795 | train_loss: 5.85e-03, test_loss: 1.56e-03
epoch 6800 | train_loss: 5.96e-03, test_loss: 1.56e-03
epoch 6805 | train_loss: 5.69e-03, test_loss: 1.62e-03
epoch 6810 | train_loss: 5.60e-03, test_loss: 1.58e-03
epoch 6815 | train_loss: 5.82e-03, test_loss: 1.50e-03
epoch 6820 | train_loss: 5.69e-03, test_loss: 1.60e-03
epoch 6825

epoch 7480 | train_loss: 5.54e-03, test_loss: 1.52e-03
epoch 7485 | train_loss: 5.96e-03, test_loss: 1.50e-03
epoch 7490 | train_loss: 5.57e-03, test_loss: 1.54e-03
epoch 7495 | train_loss: 5.60e-03, test_loss: 1.68e-03
epoch 7500 | train_loss: 5.69e-03, test_loss: 1.49e-03
epoch 7505 | train_loss: 5.69e-03, test_loss: 1.56e-03
epoch 7510 | train_loss: 5.94e-03, test_loss: 1.58e-03
epoch 7515 | train_loss: 5.67e-03, test_loss: 1.56e-03
epoch 7520 | train_loss: 5.60e-03, test_loss: 1.51e-03
epoch 7525 | train_loss: 5.62e-03, test_loss: 1.55e-03
epoch 7530 | train_loss: 5.91e-03, test_loss: 1.52e-03
epoch 7535 | train_loss: 5.66e-03, test_loss: 1.51e-03
epoch 7540 | train_loss: 5.77e-03, test_loss: 1.55e-03
epoch 7545 | train_loss: 5.66e-03, test_loss: 1.58e-03
epoch 7550 | train_loss: 5.98e-03, test_loss: 1.51e-03
epoch 7555 | train_loss: 5.54e-03, test_loss: 1.59e-03
epoch 7560 | train_loss: 5.61e-03, test_loss: 1.57e-03
epoch 7565 | train_loss: 5.60e-03, test_loss: 1.60e-03
epoch 7570

epoch 8225 | train_loss: 5.68e-03, test_loss: 1.51e-03
epoch 8230 | train_loss: 5.66e-03, test_loss: 1.48e-03
epoch 8235 | train_loss: 5.76e-03, test_loss: 1.58e-03
epoch 8240 | train_loss: 5.58e-03, test_loss: 1.53e-03
epoch 8245 | train_loss: 5.69e-03, test_loss: 1.60e-03
epoch 8250 | train_loss: 5.59e-03, test_loss: 1.65e-03
epoch 8255 | train_loss: 5.89e-03, test_loss: 1.54e-03
epoch 8260 | train_loss: 5.93e-03, test_loss: 1.61e-03
epoch 8265 | train_loss: 5.51e-03, test_loss: 1.69e-03
epoch 8270 | train_loss: 5.49e-03, test_loss: 1.68e-03
epoch 8275 | train_loss: 5.79e-03, test_loss: 1.52e-03
epoch 8280 | train_loss: 5.59e-03, test_loss: 1.59e-03
epoch 8285 | train_loss: 5.61e-03, test_loss: 1.56e-03
epoch 8290 | train_loss: 5.63e-03, test_loss: 1.50e-03
epoch 8295 | train_loss: 5.50e-03, test_loss: 1.56e-03
epoch 8300 | train_loss: 5.54e-03, test_loss: 1.60e-03
epoch 8305 | train_loss: 5.57e-03, test_loss: 1.63e-03
epoch 8310 | train_loss: 5.76e-03, test_loss: 1.50e-03
epoch 8315

epoch 8970 | train_loss: 5.82e-03, test_loss: 1.55e-03
epoch 8975 | train_loss: 5.52e-03, test_loss: 1.57e-03
epoch 8980 | train_loss: 5.44e-03, test_loss: 1.65e-03
epoch 8985 | train_loss: 5.32e-03, test_loss: 1.55e-03
epoch 8990 | train_loss: 5.95e-03, test_loss: 1.54e-03
epoch 8995 | train_loss: 5.63e-03, test_loss: 1.63e-03
epoch 9000 | train_loss: 5.43e-03, test_loss: 1.56e-03
epoch 9005 | train_loss: 5.37e-03, test_loss: 1.65e-03
epoch 9010 | train_loss: 5.68e-03, test_loss: 1.50e-03
epoch 9015 | train_loss: 5.32e-03, test_loss: 1.56e-03
epoch 9020 | train_loss: 5.49e-03, test_loss: 1.48e-03
epoch 9025 | train_loss: 5.22e-03, test_loss: 1.60e-03
epoch 9030 | train_loss: 5.77e-03, test_loss: 1.48e-03
epoch 9035 | train_loss: 5.39e-03, test_loss: 1.49e-03
epoch 9040 | train_loss: 5.73e-03, test_loss: 1.48e-03
epoch 9045 | train_loss: 5.37e-03, test_loss: 1.50e-03
epoch 9050 | train_loss: 6.00e-03, test_loss: 1.57e-03
epoch 9055 | train_loss: 5.64e-03, test_loss: 1.53e-03
epoch 9060

epoch 9715 | train_loss: 5.75e-03, test_loss: 1.57e-03
epoch 9720 | train_loss: 5.51e-03, test_loss: 1.59e-03
epoch 9725 | train_loss: 5.72e-03, test_loss: 1.48e-03
epoch 9730 | train_loss: 5.77e-03, test_loss: 1.51e-03
epoch 9735 | train_loss: 5.50e-03, test_loss: 1.48e-03
epoch 9740 | train_loss: 5.62e-03, test_loss: 1.54e-03
epoch 9745 | train_loss: 5.31e-03, test_loss: 1.55e-03
epoch 9750 | train_loss: 5.48e-03, test_loss: 1.50e-03
epoch 9755 | train_loss: 5.52e-03, test_loss: 1.50e-03
epoch 9760 | train_loss: 5.65e-03, test_loss: 1.57e-03
epoch 9765 | train_loss: 5.45e-03, test_loss: 1.50e-03
epoch 9770 | train_loss: 5.52e-03, test_loss: 1.54e-03
epoch 9775 | train_loss: 5.54e-03, test_loss: 1.54e-03
epoch 9780 | train_loss: 5.54e-03, test_loss: 1.57e-03
epoch 9785 | train_loss: 5.61e-03, test_loss: 1.62e-03
epoch 9790 | train_loss: 5.59e-03, test_loss: 1.61e-03
epoch 9795 | train_loss: 5.63e-03, test_loss: 1.65e-03
epoch 9800 | train_loss: 5.56e-03, test_loss: 1.49e-03
epoch 9805

### Save the trained model and the training logs

In [None]:
import os

# Create the output directory first: torch.save cannot write into a missing
# directory, and failing here would discard the whole training run.
os.makedirs('./save', exist_ok=True)

# Store the hyperparameters, the loss history and the weights together.
# model.cpu() moves the model itself to the CPU before its state dict is taken,
# so the model stays on the CPU after this cell.
torch.save((tuning_dictionary, logs, model.cpu().state_dict()), './save/trained_GNN_model.pt')

### Plotting Loss

In [None]:
# Plot the per-epoch training loss on a logarithmic y-axis and save it to disk.
fig, ax = plt.subplots()
ax.plot(logs['train_loss'], label='train')
# ax.plot(logs['test_loss'], label='test')
ax.set_yscale('log')
ax.grid()
ax.legend()
#plt.show()
fig.savefig('./save/plot_Loss.png')