In [76]:
import math
import numpy as np
import torch
import sys
import torch.nn as nn
import torch.nn.functional as F
from timeit import default_timer
import sys
import os
sys.path.append("../../")
from utility.adam import Adam
from utility.losses import LpLoss
from utility.normalizer import UnitGaussianNormalizer
from pcno.geo_utility import compute_edge_gradient_weights

In [77]:
data_path = "../../data/curve/"

# Read every array inside a context manager so the .npz archive's file handle
# is released as soon as loading is done. `data` remains bound afterwards but
# is closed, so everything needed from it is extracted here.
with np.load(data_path+"pcno_curve_data_3_3_grad.npz") as data:
    nnodes, node_mask, nodes = data["nnodes"], data["node_mask"], data["nodes"]
    node_measures_raw = data["node_measures_raw"]
    node_measures = data["node_measures"]
    directed_edges, edge_gradient_weights = data["directed_edges"], data["edge_gradient_weights"]
    features = data["features"]

equal_weights = False  # not read anywhere in this cell

# Node weights: the raw measures divided by one global scalar, the largest
# per-sample total obtained by summing over axis 1.
# NOTE(review): axis 1 is presumably the node axis - confirm against the data generator.
node_weights = node_measures_raw/np.amax(np.sum(node_measures_raw, axis = 1))
print('use normalized raw measures')

# node_rhos = node_weights / node_measures wherever the raw measure is finite;
# all other entries keep the plain node_weights value.
indices = np.isfinite(node_measures_raw)
node_rhos = np.copy(node_weights)
node_rhos[indices] = node_rhos[indices]/node_measures[indices]


# Convert to torch tensors: float32 for real-valued arrays, int64 for the edge list.
nnodes = torch.from_numpy(nnodes)
node_mask = torch.from_numpy(node_mask)
nodes = torch.from_numpy(nodes.astype(np.float32))
node_weights = torch.from_numpy(node_weights.astype(np.float32))
node_rhos = torch.from_numpy(node_rhos.astype(np.float32))
features = torch.from_numpy(features.astype(np.float32))
directed_edges = torch.from_numpy(directed_edges.astype(np.int64))
edge_gradient_weights = torch.from_numpy(edge_gradient_weights.astype(np.float32))

nodes_input = nodes.clone()
# N is the sample count that test_model uses to build its index range.
N = 1000
# n_test is rebound later in the evaluation cell; n_train is what gets passed
# to test_normalizer there.
n_train, n_test = 900, 100


# The dataset is kept whole (no train/test slicing here); test_normalizer
# takes the first n_train samples itself when building the normalizers.
# Model input: first feature channel, then node coordinates, then node_rhos,
# concatenated along the last dimension. Target: the remaining feature channels.
x_all = torch.cat((features[:, :, :1], nodes_input, node_rhos), -1)
y_all = features[:, :, 1:]
aux_all = (node_mask, nodes, node_weights, directed_edges, edge_gradient_weights)

# Normalization options, forwarded to UnitGaussianNormalizer by test_normalizer.
normalization_x = False
normalization_y = True
normalization_dim_x = []
normalization_dim_y = []
non_normalized_dim_x = 4
non_normalized_dim_y = 0

config = {"train" : {"normalization_x": normalization_x,"normalization_y": normalization_y, 
                     "normalization_dim_x": normalization_dim_x, "normalization_dim_y": normalization_dim_y, 
                     "non_normalized_dim_x": non_normalized_dim_x, "non_normalized_dim_y": non_normalized_dim_y}
                     }

use normalized raw measures


In [78]:
import matplotlib.pyplot as plt
from contextlib import redirect_stdout
from pcno.pcno import compute_Fourier_modes, PCNO, PCNO_train
from generate_curves_data import compute_unit_normals
import matplotlib as mpl

# Global matplotlib text settings used by every figure drawn in this notebook.
mpl.rcParams.update({
    'font.family': 'Times New Roman',
    'font.size': 17,
})
def test_normalizer(all_tuple, config, model_n_train):
    x_all,y_all,node_mask,nodes,node_weights,directed_edges,edge_gradient_weights = all_tuple
    x_all_copy = x_all.clone()
    y_all_copy = y_all.clone()    

    normalization_x, normalization_y = config["train"]["normalization_x"], config["train"]["normalization_y"]
    normalization_dim_x, normalization_dim_y = config["train"]["normalization_dim_x"], config["train"]["normalization_dim_y"]
    non_normalized_dim_x, non_normalized_dim_y = config["train"]["non_normalized_dim_x"], config["train"]["non_normalized_dim_y"]


    x_train = x_all_copy[:model_n_train]
    y_train = y_all_copy[:model_n_train]

    if normalization_x:
        x_normalizer = UnitGaussianNormalizer(x_train, non_normalized_dim = non_normalized_dim_x, normalization_dim=normalization_dim_x)
        x_test = x_normalizer.encode(x_all_copy)
    else:
        x_normalizer = None
        x_test =  x_all_copy
    if normalization_y:
        y_normalizer = UnitGaussianNormalizer(y_train, non_normalized_dim = non_normalized_dim_y, normalization_dim=normalization_dim_y)
        y_test = y_normalizer.encode(y_all_copy)
        y_normalizer.to(device)
    else:
        y_normalizer = None
        y_test = y_all_copy
    test_tuple = (x_test, y_test, node_mask,nodes,node_weights,directed_edges,edge_gradient_weights)

    return test_tuple, y_normalizer

def test_model(model,n_test,test_tuple, y_normalizer):
    """Score ``model`` one sample at a time and collect the per-sample loss.

    Samples ``N - n_test`` to ``N - 1`` are evaluated, where ``N`` and
    ``device`` are read from module scope.

    Args:
        model: callable invoked as ``model(x, aux_batch)``.
        n_test: number of trailing samples (relative to ``N``) to evaluate.
        test_tuple: 7-tuple ``(x, y, node_mask, nodes, node_weights,
            directed_edges, edge_gradient_weights)``.
        y_normalizer: if truthy, its ``decode`` is applied to both the
            prediction and the target before the loss is computed.

    Returns:
        ``(rel_l2, index)``: the list of per-sample ``LpLoss`` values and the
        list of matching sample indices.
    """
    (x_test, y_test, node_mask, nodes, node_weights,
     directed_edges, edge_gradient_weights) = test_tuple

    loss_fn = LpLoss(d=1, p=2, size_average=False)
    rel_l2, index = [], []

    with torch.no_grad():
        for i in range(N-n_test, N):
            one = slice(i, i+1)  # keeps the leading batch dimension of size 1

            x = x_test[one].to(device)
            y = y_test[one].to(device)
            aux_batch = tuple(
                t[one].to(device)
                for t in (node_mask, nodes, node_weights, directed_edges, edge_gradient_weights)
            )

            out = model(x, aux_batch)
            if y_normalizer:
                out = y_normalizer.decode(out)
                y = y_normalizer.decode(y)

            batch_size_ = x.shape[0]
            # Zero the padded entries of the prediction (mask: 1 for node, 0 for padding).
            out = out * node_mask[one].to(device)
            test_rel_l2 = loss_fn(out.view(batch_size_,-1), y.view(batch_size_,-1)).item()

            rel_l2.append(test_rel_l2)
            index.append(i)
            print(f'test index: {i}, test_rel_l2: {test_rel_l2}')

    return rel_l2, index

def sorted_result( rel_l2, index):
    """Print a summary of per-sample losses and pick three representative samples.

    Args:
        rel_l2: non-empty sequence of per-sample losses.
        index: sequence of sample indices, aligned with ``rel_l2``.

    Returns:
        ``(average_loss, index_3)`` where ``index_3`` holds the sample indices
        of the worst loss, the "medium" loss (element at position
        ``len // 2`` of the descending ranking) and the best loss.

    Raises:
        ValueError: if ``rel_l2`` is empty (no average can be computed).
    """
    if len(rel_l2) == 0:
        raise ValueError('sorted_result needs at least one loss value')

    # (position in rel_l2, loss) pairs, largest loss first.
    sorted_l2 = sorted(enumerate(rel_l2), key=lambda pair: pair[1], reverse=True)
    average_loss = sum(rel_l2)/len(rel_l2)
    print('average_rel_l2_of all :  ',round(average_loss,5), flush = True)
    print()

    # Report at most three entries per side; fewer when fewer losses are given,
    # so short inputs no longer run past the end of the ranking.
    n = min(3, len(sorted_l2))
    medium_pos, medium_loss = sorted_l2[len(sorted_l2)//2]
    index_3 = [index[sorted_l2[0][0]], index[medium_pos], index[sorted_l2[-1][0]]]

    for j in range(n):
        print(f'{j+1}th_worst_rel_l2_of all :  ',round(sorted_l2[j][1],5), ' index : ',index[sorted_l2[j][0]])
    # Printed once, after the worst-case list (it used to repeat on every loop pass).
    print('medium_rel_l2_of all : ',round(medium_loss,5), ' index : ',index_3[1])
    for j in range(n):
        print(f'{j+1}th_best_rel_l2_of all :  ',round(sorted_l2[-j-1][1],5), ' index : ',index[sorted_l2[-j-1][0]],flush = True)
    print()
    return average_loss,index_3



def myplot(index_plot,save_figure_path,
           model,test_tuple,y_normalizer):
    """Plot input, ground truth, prediction and error for selected samples.

    One figure row is drawn per entry of ``index_plot``, with four panels:
    channel 0 of the input, the target, the model output and their
    difference. The figure is saved as ``test_{index_plot}.png`` and closed.

    Args:
        index_plot: sequence of sample indices to plot (one row each).
        save_figure_path: output directory prefix; it is created if missing
            and the file name is appended to it by string concatenation, so
            it should end with a path separator.
        model: callable invoked as ``model(x, aux_batch)``.
        test_tuple: 7-tuple ``(x, y, node_mask, nodes, node_weights,
            directed_edges, edge_gradient_weights)``.
        y_normalizer: if truthy, its ``decode`` is applied to prediction and
            target before plotting.

    Note:
        Reads ``device`` from module scope.
    """
    x_test,y_test,node_mask,nodes,node_weights,directed_edges,edge_gradient_weights = test_tuple
    myloss = LpLoss(d=1, p=2, size_average=False)

    with torch.no_grad():

        # squeeze=False keeps `axs` two-dimensional even when only one sample
        # is requested, so axs[j, col] works for any number of rows.
        fig, axs = plt.subplots(len(index_plot), 4, figsize=(20, 4*len(index_plot)), squeeze=False)

        try:
            for j, i in enumerate(index_plot):
                x, y = x_test[i:i+1].to(device), y_test[i:i+1].to(device)
                aux_batch = (
                    node_mask[i:i+1].to(device), nodes[i:i+1].to(device),
                    node_weights[i:i+1].to(device), directed_edges[i:i+1].to(device),
                    edge_gradient_weights[i:i+1].to(device)
                )
                out = model(x, aux_batch)
                if y_normalizer:
                    out = y_normalizer.decode(out)
                    y = y_normalizer.decode(y)
                batch_size_ = x.shape[0]
                # Zero the padded entries of the prediction before scoring.
                out = out * node_mask[i:i+1].to(device)
                test_rel_l2 = myloss(out.view(batch_size_,-1), y.view(batch_size_,-1)).item()

                nodes_plot = nodes[i].detach().cpu()
                normal_plot = compute_unit_normals(nodes_plot, None)
                f_plot = x_test[i:i+1][:,:,0].reshape(-1).detach().cpu()
                g_plot = y.reshape(-1).detach().cpu()
                out_plot = out.reshape(-1).detach().cpu()
                error_plot = out_plot - g_plot

                # Consistent color scale for g and prediction
                vmin_go = min(g_plot.min().item(), out_plot.min().item())
                vmax_go = max(g_plot.max().item(), out_plot.max().item())
                norm_go = mpl.colors.Normalize(vmin=vmin_go, vmax=vmax_go)

                # Symmetric color scale around 0 for error
                vmax_err = torch.max(torch.abs(error_plot)).item()
                norm_err = mpl.colors.TwoSlopeNorm(vmin=-vmax_err, vcenter=0.0, vmax=vmax_err)

                # (title, values, colormap, color normalization) for each column.
                panels = (
                    ('Input f(x)', f_plot, 'viridis', None),
                    ('Ground Truth g(x)', g_plot, 'viridis', norm_go),
                    ('Prediction g_pred(x)', out_plot, 'viridis', norm_go),
                    (f'Error, loss = {round(test_rel_l2,5)}', error_plot, 'coolwarm', norm_err),
                )
                for col, (title, values, cmap, norm) in enumerate(panels):
                    ax = axs[j, col]
                    ax.plot(nodes_plot[:, 0], nodes_plot[:, 1], color='blue', alpha=0.5)
                    scatter = ax.scatter(nodes_plot[:, 0], nodes_plot[:, 1], c=values, cmap=cmap, s=40, norm=norm)
                    if col == 0:
                        # Only the input panel also shows the vectors returned by compute_unit_normals.
                        ax.quiver(nodes_plot[:, 0], nodes_plot[:, 1], normal_plot[:, 0], normal_plot[:, 1], color='red', scale=10, width=0.001, alpha=0.7)
                    ax.set_title(title)
                    ax.axis('equal')
                    fig.colorbar(scatter, ax=ax)

            plt.tight_layout()
            os.makedirs(save_figure_path, exist_ok=True)
            fig.savefig(save_figure_path + f'test_{index_plot}.png', format='png', bbox_inches='tight')
        finally:
            # Release the figure even when plotting or saving fails.
            plt.close(fig)

# Rebinds the n_test chosen in the data-loading cell. It is only consumed by
# the (currently commented-out) test_model call below, which evaluates samples
# N - n_test .. N - 1.
n_test = 1000
# Prefer the GPU, but fall back to the CPU so the cell also runs on machines
# without CUDA. The helper functions above read this module-level name.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Everything printed inside this block is appended to output.txt instead of
# being shown in the notebook.
with open('output.txt', 'a') as f:
    with redirect_stdout(f):

        # NOTE: absolute, machine-specific locations; adjust when running elsewhere.
        model_path = 'E:/codes/mygithub2/scripts/curve/' + 'model/PCNO_curve_model_k8_L10.pth'
        save_figure_path = 'E:/codes/mygithub2/scripts/curve/' + 'figures/'


        k_max = 8
        ndim = 2
        L = 10

        model_train_inv_L_scale = False
        modes = compute_Fourier_modes(ndim, [k_max,k_max], [L,L])
        modes = torch.tensor(modes, dtype=torch.float).to(device)
        model = PCNO(ndim, modes, nmeasures=1,
                    layers=[128,128,128,128,128],
                    fc_dim=128,
                    in_dim=x_all.shape[-1], out_dim=y_all.shape[-1],
                    inv_L_scale_hyper = [model_train_inv_L_scale, 0.5, 2.0],
                    act='gelu').to(device)

        # The checkpoint is passed directly to load_state_dict, so it is
        # expected to be a state dict; map_location keeps it loadable on the
        # selected device.
        checkpoint = torch.load(model_path, map_location=device)
        model.load_state_dict(checkpoint)

        print(f'\n\nNew model : {model_path}', flush = True)
        tuple_all = (x_all,y_all,node_mask,nodes,node_weights,directed_edges,edge_gradient_weights)
        test_tuple, y_normalizer = test_normalizer(tuple_all, config, n_train)
        # rel_l2, index = test_model(model, n_test, test_tuple, y_normalizer)
        # average_loss_list,index_3 = sorted_result(rel_l2, index)

        test_index = [925,952,999]
        myplot(test_index,save_figure_path,
                model,test_tuple, y_normalizer)



In [79]:
# loss_exponential = np.array(rel_l2_exponential)
# loss_linear = np.array(rel_l2_linear)
# loss_uniform = np.array(rel_l2_uniform)
# np.savez('test_result/test_losses.npz', loss_exponential=loss_exponential, loss_linear=loss_linear, loss_uniform = loss_uniform)