In [1]:
import numpy as np 
import matplotlib.pyplot as plt 
import torch_geometric as pyg
import torch
import pandas as pd
from torch_geometric.utils import to_undirected
import networkx as nx
from utils import dataset_gen, dataset_gen_notscaled, minmaxscaler
import seaborn as sns
from aneurysm_interp import GNN
import os
from torch_geometric.loader import DataLoader
from IPython.display import clear_output

device = torch.device('cpu')

In [3]:
indicators = ['WSS','OSI']#

sub_ind = [indicators[0]]

conv_list = ['gconv','gin','gcn', 'gtr']
long_conv_list = ['GraphConv', 'GIN', 'GCN', 'Graph Transformer']
input_size = 7 # dataset features: Time, Press_SA, Press_abd, FlowRate, coord(x), coord(y), coord(z)
hidden_size = 32
num_layers = 3
perc = 75
output_size = len(sub_ind)
dataset= dataset_gen(perc,sub_ind)
num_nodes = dataset[0].num_nodes
T = len(dataset)
loader = DataLoader(dataset, batch_size = 1, shuffle=False)

y = torch.zeros((T,num_nodes))
for i, data in enumerate(loader):
    y[i,:] = data.y.squeeze()
t = torch.linspace(0,1,T)






In [29]:
pred = torch.zeros((len(conv_list),T,num_nodes))
best_it = [3, 1, 2, 2]
for j,conv_type in enumerate(conv_list):
        task_type = 'interp'
        log_name = '_'.join([conv_type, 'loss'])
        log_path = 'interp/' + '_'.join(['architecture']+sub_ind)

        model = GNN(input_size,hidden_size,output_size,num_layers,conv_type=conv_type,device=device)

        model.load_state_dict(torch.load(log_path+'/'+log_name + str(best_it[j])+'.pt', map_location=torch.device('cpu')))

        

        for i, data in enumerate(loader):
            
            pred[j,i,:] = model(data.x, data.edge_index).squeeze()

In [31]:
folder = '_'.join(['pred', 'architecture', sub_ind[0]])
if not os.path.exists(folder):
    os.makedirs(folder)

for node in range(500, 1500):


    # for node in range(num_nodes):

    fig = plt.figure(figsize=(16, 8))
    ax1 = fig.add_subplot(141)
    ax1.plot(t, pred[0,:,node].detach(), label='prediction')
    ax1.plot(t, y[:,node].detach(), label='ground_truth')
    ax1.set_xlabel('Timestamp t')
    ax1.set_ylabel('Wall Shear Stress')
    ax1.set_title(long_conv_list[0])
    ax1.legend()

    ax2 = fig.add_subplot(142, sharey=ax1)
    ax2.plot(t, pred[1,:,node].detach(), label='prediction')
    ax2.plot(t, y[:,node].detach(), label='ground_truth')
    ax2.set_xlabel('Timestamp t')
    ax2.set_title(long_conv_list[1])
    ax2.legend()

    ax3 = fig.add_subplot(143,sharey=ax1)
    ax3.plot(t, pred[2,:,node].detach(), label='prediction')
    ax3.plot(t, y[:,node].detach(), label='ground_truth')
    ax3.set_xlabel('Timestamp t')
    ax3.set_title(long_conv_list[2])
    ax3.legend()

    ax4 = fig.add_subplot(144,sharey=ax1)
    ax4.plot(t, pred[3,:,node].detach(), label='prediction')
    ax4.plot(t, y[:,node].detach(), label='ground_truth')
    ax4.set_xlabel('Timestamp t')
    ax4.set_title(long_conv_list[3])
    ax4.legend()

    # plt.show()

    plt.savefig(folder+ '/prediction_node_'+str(node)+'.pdf')
    plt.close()

## TAWSS comparison

In [11]:
errors = []
T = len(dataset)
dt = 1/T
data = minmaxscaler(np.load(str(perc)+'Percent.npz')['WSS'])
data[0,:] = data[0,:]/2
data[-1,:] = data[-1,:]/2
TAWSS = dt * np.sum(np.abs(data), axis = 0)
for it in range(5):
    pred = torch.load('pred'+str(it)+'.pt').detach()
    for j in range(len(conv_list)):
        pred[j,0,:] /= 2
        pred[j,-1,:] /=2
        TAWSSp = dt * np.sum(np.abs(pred[j,:,:].detach().numpy()), axis = 0)
        error = np.linalg.norm(TAWSS-TAWSSp,ord= np.inf)/np.linalg.norm(TAWSS, ord = np.inf)
        errors.append({'conv_type':conv_list[j], 'it':it, 'error': error})

In [12]:
errors

[{'conv_type': 'gconv', 'it': 0, 'error': 0.4355913942667632},
 {'conv_type': 'gin', 'it': 0, 'error': 0.9134659286021182},
 {'conv_type': 'gcn', 'it': 0, 'error': 0.5099459476356111},
 {'conv_type': 'gtr', 'it': 0, 'error': 0.42014311047339825},
 {'conv_type': 'gconv', 'it': 1, 'error': 0.4280055703066529},
 {'conv_type': 'gin', 'it': 1, 'error': 0.830777648952503},
 {'conv_type': 'gcn', 'it': 1, 'error': 0.37028790248433957},
 {'conv_type': 'gtr', 'it': 1, 'error': 0.45688621261790713},
 {'conv_type': 'gconv', 'it': 2, 'error': 0.4120113177371765},
 {'conv_type': 'gin', 'it': 2, 'error': 0.9061535362323161},
 {'conv_type': 'gcn', 'it': 2, 'error': 0.31457376465172326},
 {'conv_type': 'gtr', 'it': 2, 'error': 0.35218609742119683},
 {'conv_type': 'gconv', 'it': 3, 'error': 0.3979247521169186},
 {'conv_type': 'gin', 'it': 3, 'error': 1.5180260158161583},
 {'conv_type': 'gcn', 'it': 3, 'error': 0.40709692626500815},
 {'conv_type': 'gtr', 'it': 3, 'error': 0.4315511953118443},
 {'conv_typ

In [13]:
db = pd.DataFrame.from_records(errors)

In [26]:
db.groupby('conv_type').describe()

Unnamed: 0_level_0,it,it,it,it,it,it,it,it,error,error,error,error,error,error,error,error
Unnamed: 0_level_1,count,mean,std,min,25%,50%,75%,max,count,mean,std,min,25%,50%,75%,max
conv_type,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2
gcn,5.0,2.0,1.581139,0.0,1.0,2.0,3.0,4.0,5.0,0.401948,0.071347,0.314574,0.370288,0.407097,0.407834,0.509946
gconv,5.0,2.0,1.581139,0.0,1.0,2.0,3.0,4.0,5.0,0.427096,0.024321,0.397925,0.412011,0.428006,0.435591,0.461947
gin,5.0,2.0,1.581139,0.0,1.0,2.0,3.0,4.0,5.0,1.013493,0.283974,0.830778,0.899041,0.906154,0.913466,1.518026
gtr,5.0,2.0,1.581139,0.0,1.0,2.0,3.0,4.0,5.0,0.422519,0.042054,0.352186,0.420143,0.431551,0.45183,0.456886
