In [6]:
# Make the repository's own package importable and pull in its public API.
import sys
import os
from pathlib import Path

# Put the current working directory on sys.path so `geograph_interpolator`
# can be imported as if it were an installed module.
# NOTE(review): assumes the notebook is launched from the repository root.
cwd = os.getcwd()
geograph_interpolator_root = str(Path(cwd))
# Only add the entry once, so re-running this cell does not keep growing sys.path.
if geograph_interpolator_root not in sys.path:
    sys.path.append(geograph_interpolator_root)

# Star import kept because later cells use names it provides (load_graph_nodes,
# load_graph_edges, visualise_3D, create_dgl_graph, GCN, wGCN, GAT,
# gnn_interpolate; presumably also plt / pv / torch).
# TODO: replace with explicit imports.
from geograph_interpolator import *

In [9]:
# Load the demo data shipped with the repository.
# Paths are built with pathlib rather than by concatenating '\Data\...' strings:
# the backslash form only works on Windows, and '\D' / '\g' are invalid escape
# sequences (SyntaxWarning on Python 3.12+).
data_dir = Path(geograph_interpolator_root) / 'Data'
node_file_name      = str(data_dir / 'graph_features.csv')
edge_file_name_sim  = str(data_dir / 'graph_edges_similar.csv')
edge_file_name_lat  = str(data_dir / 'graph_edges_lattice.csv')

core_node_list  = load_graph_nodes(node_file_name)
lattice_edges   = load_graph_edges(edge_file_name_lat)
similar_edges   = load_graph_edges(edge_file_name_sim)

In [None]:
# 3D view of the nodes coloured by one node parameter, using the "similar" edge list.
# The available keys can be listed with core_node_list[0].params.keys(), for example:
#   role - whether the node is a training or a testing node
#   labl - the node label
#   cond - a synthetic ("pretend") conductivity value
param_key = 'cond'
visualise_3D(core_node_list, param_key, similar_edges)

In [None]:
# Train three GNN variants (GCN, wGCN, GAT) on the lattice graph and keep
# their losses and logits for the comparison cells below.
# (The former `global graph ...` lines were dropped: `global` at notebook top
# level is a no-op, so they declared nothing.)
import torch

SEED = 0      # fixes the torch RNG so Restart & Run All reproduces the results
EPOCHS = 300  # training epochs per model; later cells index the logits by epoch

# NOTE(review): torch.manual_seed covers torch-initialised weights; if the models
# or gnn_interpolate also draw from numpy / DGL RNGs, seed those as well - confirm.
torch.manual_seed(SEED)

# feats: a list of keys to extract from graph.ndata, or a number for a random
# embedding of that size; either way it determines the model input size.
feats = ['data', 'zloc_n']
insize = len(feats) if isinstance(feats, list) else feats

# `graph` is not trained on here; the plotting cell reads node coordinates from it.
# Each model gets its own freshly built copy of the same graph - presumably
# because gnn_interpolate modifies the graph it is given (TODO confirm).
graph = create_dgl_graph(core_node_list, lattice_edges)
graph0 = create_dgl_graph(core_node_list, lattice_edges)
graph1 = create_dgl_graph(core_node_list, lattice_edges)
graph2 = create_dgl_graph(core_node_list, lattice_edges)

# Build the models. NOTE(review): the positional args (insize, 5, 2[, 5]) are
# presumably (input, hidden, classes[, heads]) - confirm against the class signatures.
model_GCN = GCN(insize, 5, 2)
modelwGCN = wGCN(insize, 5, 2)
model_GAT = GAT(insize, 5, 2, 5)

# Run the interpolation with each model.
gcn_loss, gcn_logits = gnn_interpolate(model_GCN, graph0, node_feats=feats, epochs=EPOCHS)
wgcnloss, wgcnlogits = gnn_interpolate(modelwGCN, graph1, node_feats=feats, epochs=EPOCHS)
gat_loss, gat_logits = gnn_interpolate(model_GAT, graph2, node_feats=feats, epochs=EPOCHS)

In [None]:
# Compare the training-loss curves of the three models.
# NOTE(review): the x axis assumes each *_loss holds one value per epoch - confirm in gnn_interpolate.
fig, ax = plt.subplots()
ax.plot(gat_loss, label='GAT')
ax.plot(gcn_loss, label='GCN')
ax.plot(wgcnloss, label='wGCN')
ax.set(title='Loss Value', xlabel='Epoch', ylabel='Loss')
ax.legend()
plt.show()

In [None]:
# Interactive comparison of the three models' predicted classes.
# Opens a separate (non-notebook) PyVista window with one panel per model and an
# 'Epoch' slider that selects which entry of each logits sequence is shown.
plotter = pv.Plotter(shape=(1, 3), notebook=False, window_size=(1920, 1000))

# Node coordinates do not depend on the slider, so build the point list once
# instead of on every callback.
node_x = graph.ndata['xloc'].flatten().detach().numpy()
node_y = graph.ndata['yloc'].flatten().detach().numpy()
node_z = graph.ndata['zloc'].flatten().detach().numpy()
node_vertices = list(zip(node_x, node_y, node_z))

# (subplot position, panel title, logits sequence, mesh actor name) per model.
model_panels = (
    ((0, 0), 'GCN', gcn_logits, 'gcn_pointcloud'),
    ((0, 1), 'wGCN', wgcnlogits, 'wgcnpointcloud'),
    ((0, 2), 'GAT', gat_logits, 'gat_pointcloud'),
)


def create_mesh(value):
    """Slider callback: redraw every panel with the predictions at index ``value``.

    Parameters
    ----------
    value : float
        Slider position, rounded to the nearest integer index into each
        model's logits sequence.

    For each model, the prediction is ``torch.argmax`` over dim 1 of the
    selected logits. It is attached to a point cloud of the graph nodes as the
    ``'pred'`` array. Meshes and titles are added with fixed actor names, so
    each call replaces the previous actors instead of stacking new ones.
    """
    epoch = int(round(value))

    for (row, col), title, logits, actor_name in model_panels:
        preds = torch.argmax(logits[epoch], dim=1).detach().cpu().numpy()

        cloud = pv.PolyData(node_vertices)
        cloud['pred'] = preds

        plotter.subplot(row, col)
        # name= makes the title replace itself on each slider move; without it
        # a new text actor was added on every callback.
        plotter.add_text(title, font_size=30, name=f'{actor_name}_title')
        plotter.add_mesh(cloud, name=actor_name)


# The slider covers the indices actually recorded, instead of a hard-coded [0, 299].
# NOTE(review): assumes each logits sequence has one entry per training epoch.
last_epoch = min(len(gcn_logits), len(wgcnlogits), len(gat_logits)) - 1
plotter.add_slider_widget(create_mesh, [0, last_epoch], title='Epoch')
plotter.show()