In [1]:
import numpy as np 
import matplotlib.pyplot as plt 
import torch_geometric as pyg
import torch
import pandas as pd
from torch_geometric.utils import to_undirected
import networkx as nx
from utils import dataset_gen, minmaxscaler
import seaborn as sns
from model import GNN
import os
from torch_geometric.loader import DataLoader
from IPython.display import clear_output

device = torch.device('cpu')

In [3]:
indicators = ['WSS','OSI']#

sub_ind = [indicators[0]]

conv_list = ['gconv','gin','gcn', 'gtr']
long_conv_list = ['GraphConv', 'GIN', 'GCN', 'Graph Transformer']
input_size = 7 # dataset features: Time, Press_SA, Press_abd, FlowRate, coord(x), coord(y), coord(z)
hidden_size = 32
num_layers = 3
perc = 75
output_size = len(sub_ind)
dataset= dataset_gen(perc,sub_ind)
num_nodes = dataset[0].num_nodes
T = len(dataset)
loader = DataLoader(dataset, batch_size = T, shuffle=False)

y = torch.zeros((T,num_nodes))
for i, data in enumerate(loader):
    y = data.y.reshape((T,num_nodes))
t = torch.linspace(0,1,T)






In [4]:
pred_dict = {}
task_type = 'interp'
for conv_type in conv_list:
    pred_dict[conv_type] = {}
    for it in range(5):
        pred = torch.zeros((T,num_nodes))
        
        log_name = '_'.join([conv_type, 'loss'])
        log_path = 'interp/' + '_'.join(['architecture']+sub_ind)

        model = GNN(input_size,hidden_size,output_size,num_layers,conv_type=conv_type,device=device)

        model.load_state_dict(torch.load(log_path+'/'+log_name + str(it)+'.pt', map_location=torch.device('cpu')))

        

        for i, data in enumerate(loader):
            
            pred[:,:] = model(data.x, data.edge_index).reshape((T,num_nodes))
        pred_dict[conv_type][str(it)] = pred

In [2]:
nodes = [208,305,400,800,852]
data_raw = []
for i, conv_type in enumerate(conv_list):
    for it in range(5):       
        pred = pred_dict[conv_type][str(it)]
        for node in nodes:
            for j in range(500):
                data_raw.append({'conv_type': conv_type, 'it': it, 'node': node, 't':float(t[j]), 'pred': float(pred[j,node])})
df = pd.DataFrame.from_records(data_raw)

NameError: name 'conv_list' is not defined

In [1]:

for node in nodes:
    fig, axes = plt.subplots(1,4, sharey=True, figsize=(16, 8) )
    axes[0].set_ylabel(sub_ind[0], fontsize = '16')
    for i,conv_type in enumerate(conv_list):
        red = df[df['conv_type']==conv_type][df['node']==node][['t','pred','it']]
        axes[i].set_xlabel('t', fontsize = '16')
        axes[i].tick_params(axis = 'both', which='minor', labelsize= '14')
        sns.lineplot(data = red, x="t", y="pred", ax=axes[i], label = 'Prediction')
        sns.lineplot(x = t, y = y[:,node].detach(),ax=axes[i], label = 'Ground Truth')
        axes[i].legend(fontsize = '14', loc = 'upper right')
        axes[i].set_title(long_conv_list[i])

    os.makedirs('seaborn/pred_layers_architecture/', exist_ok=True)
    plt.savefig('seaborn/pred_layers_architecture/pred'+str(node)+'.pdf')
    plt.close()

NameError: name 'nodes' is not defined

## TAWSS comparison

In [29]:
errors = []
T = len(dataset)
dt = 1/T
data = minmaxscaler(np.load(str(perc)+'Percent.npz')['WSS'])
data[0,:] = data[0,:]/2
data[-1,:] = data[-1,:]/2
TAWSS = dt * np.sum(np.abs(data), axis = 0)

for j, conv_type in enumerate(conv_list):
    for it in range(5):
        pred =  pred_dict[conv_type][str(it)]
        if torch.is_tensor(pred):
            pred = pred.detach().numpy()
        pred[0,:] /= 2
        pred[-1,:] /=2
        TAWSSp = dt * np.sum(np.abs(pred), axis = 0)
        error = np.linalg.norm(TAWSS-TAWSSp,ord= np.inf)/np.linalg.norm(TAWSS, ord = np.inf)
        errors.append({'conv_type':conv_type, 'it':it, 'error': error})

In [30]:
errors

[{'conv_type': 'gconv', 'it': 0, 'error': 0.4359148527161412},
 {'conv_type': 'gconv', 'it': 1, 'error': 0.42830641615926596},
 {'conv_type': 'gconv', 'it': 2, 'error': 0.4122823528838058},
 {'conv_type': 'gconv', 'it': 3, 'error': 0.3981820581869405},
 {'conv_type': 'gconv', 'it': 4, 'error': 0.46209571022814955},
 {'conv_type': 'gin', 'it': 0, 'error': 0.5572452436722845},
 {'conv_type': 'gin', 'it': 1, 'error': 0.5041604873554331},
 {'conv_type': 'gin', 'it': 2, 'error': 0.5325251521964138},
 {'conv_type': 'gin', 'it': 3, 'error': 0.48339074700129553},
 {'conv_type': 'gin', 'it': 4, 'error': 0.5429610762798435},
 {'conv_type': 'gcn', 'it': 0, 'error': 0.5102148760185666},
 {'conv_type': 'gcn', 'it': 1, 'error': 0.37043010903231843},
 {'conv_type': 'gcn', 'it': 2, 'error': 0.31473753041463026},
 {'conv_type': 'gcn', 'it': 3, 'error': 0.40728635943200714},
 {'conv_type': 'gcn', 'it': 4, 'error': 0.4079282080598775},
 {'conv_type': 'gtr', 'it': 0, 'error': 0.42026786599561283},
 {'conv

In [31]:
db = pd.DataFrame.from_records(errors)

In [32]:
db.groupby('conv_type').describe()

Unnamed: 0_level_0,it,it,it,it,it,it,it,it,error,error,error,error,error,error,error,error
Unnamed: 0_level_1,count,mean,std,min,25%,50%,75%,max,count,mean,std,min,25%,50%,75%,max
conv_type,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2
gcn,5.0,2.0,1.581139,0.0,1.0,2.0,3.0,4.0,5.0,0.402119,0.071388,0.314738,0.37043,0.407286,0.407928,0.510215
gconv,5.0,2.0,1.581139,0.0,1.0,2.0,3.0,4.0,5.0,0.427356,0.024287,0.398182,0.412282,0.428306,0.435915,0.462096
gin,5.0,2.0,1.581139,0.0,1.0,2.0,3.0,4.0,5.0,0.524057,0.029917,0.483391,0.50416,0.532525,0.542961,0.557245
gtr,5.0,2.0,1.581139,0.0,1.0,2.0,3.0,4.0,5.0,0.422558,0.042028,0.352242,0.420268,0.431585,0.451797,0.456901
