# Import libraries and utility functions

In [None]:
import pickle
import sys
import os

import torch
from scipy.io import loadmat
import matplotlib.pyplot as plt
import numpy as np
import torch_geometric as tg

import funcs_helpers as fh

In [None]:
# Dataset parameter (presumably the hole diameter); it is part of the file names below.
diameter = 0.9

# path to get DataFrame
df_path = f"data\\coarseMesh_noBifurcation_5\\dataframe_coarseMesh_noBifurcation_diameter_{diameter}_5.pkl"

# path to put graph data
gr_path = f'data\\coarseMesh_noBifurcation_5\\graphs_coarseMesh_noBifurcation_diameter_{diameter}_5_noBulkNodes_4.pkl'

# MATLAB file with the mesh (node positions 'p', triangles 't', boundary/corner node ids)
mesh_path = r"data\coarseMesh_noBifurcation_5\mesh_info.mat"

# whether the mesh (connectivity AND initial node positions!) is always the same or if it varies per trajectory
constant_mesh = True

# if True, every node except those on the hole boundaries is removed from the graphs
remove_bulk_nodes = True

if not constant_mesh:
    raise NotImplementedError('non-constant mesh not implemented')

# Open DataFrame

In [None]:
# Load the pickled DataFrame; the trailing bare `df` displays it in the notebook.
with open(df_path, 'rb') as df_file:
    df = pickle.load(df_file)
df

# Remove NaNs

In [None]:
# Drop every row whose W value is NaN.
print('df.shape before:', df.shape)
nan_rows = np.isnan(np.array(df['W']))
print(sum(nan_rows), 'cases where W is NaN')
df = df[~nan_rows]
print('df.shape after:', df.shape)

# Remove duplicate F

In [None]:
# Keep a single row per distinct F.
# NOTE: np.unique returns the first-occurrence indices in the sorted order of the unique
# F values, so the surviving rows of df end up ordered by F, not by original position.
F = np.stack(df['F'], axis=0)
print('F.shape:', F.shape)
unique_F, inds_to_keep = np.unique(F, axis=0, return_index=True)
print('unique_F.shape:', unique_F.shape)

print('df.shape before:', df.shape)
df = df.iloc[inds_to_keep]
print('df.shape after:', df.shape)

In [None]:
# Cache the cleaned DataFrame next to the original one.
# (The name is df_path with '_reduced2.pkl' appended, i.e. it ends in '.pkl_reduced2.pkl'.)
reduced_path = df_path + '_reduced2.pkl'
with open(reduced_path, 'wb') as out_file:
    pickle.dump(df, out_file)

# Mesh

In [None]:
# Reload the cleaned DataFrame written by the previous cell.
reduced_path = df_path + '_reduced2.pkl'
with open(reduced_path, 'rb') as in_file:
    df = pickle.load(in_file)

In [None]:
# Containers for the graph data, all keyed by quantity name:
#   e_data - data associated with edges
#   n_data - data associated with nodes
#   e_inds - index arrays into the edges
#   n_inds - index arrays into the nodes
#   g_data - graph-level data (one row per load case)
e_data, n_data, e_inds, n_inds, g_data = {}, {}, {}, {}, {}

In [None]:
# Load the mesh file. MATLAB indices are 1-based, hence the "- 1" on every index below.
mesh_vars = loadmat(mesh_path)

# list the variables stored in the mesh file
for var_name in mesh_vars:
    print(var_name)

n_data['pos'] = mesh_vars['p'] if constant_mesh else np.stack(df['pos'].values)

# triangle connectivity (rows = local node number, columns = elements), 0-based
n_inds['t'] = mesh_vars['t'].astype(int) - 1

# boundary node indices of the four sides (IDGamma_1..4 = bottom, right, top, left)
for side, gamma_id in (('b_bottom', 1), ('b_right', 2), ('b_top', 3), ('b_left', 4)):
    n_inds[side] = mesh_vars[f'IDGamma_{gamma_id}'][0] - 1

# corner nodes: id1 is the fixed corner, id2-id4 are the dependent ones
n_inds['dependent_corner_inds'] = np.asarray([mesh_vars[f'id{k}'][0, 0] - 1 for k in (2, 3, 4)])
n_inds['fixed_corner_ind'] = mesh_vars['id1'][0, 0] - 1


In [None]:
 # extract boundary nodes
# Re-derive the side and corner node indices geometrically from the node positions
# (this overrides the indices read from the mesh file in the previous cell).
tol_g = 1e-6
x_coords, y_coords = n_data['pos'][0], n_data['pos'][1]
x_min, x_max = np.min(x_coords), np.max(x_coords)
y_min, y_max = np.min(y_coords), np.max(y_coords)

# boolean masks of the nodes lying on each side (within tolerance tol_g)
on_bottom = y_coords < y_min + tol_g
on_right = x_coords > x_max - tol_g
on_top = y_coords > y_max - tol_g
on_left = x_coords < x_min + tol_g

# side nodes, sorted along their side: bottom/top by x, left/right by y
for side, mask, sort_coord in (('b_bottom', on_bottom, x_coords),
                               ('b_right', on_right, y_coords),
                               ('b_top', on_top, x_coords),
                               ('b_left', on_left, y_coords)):
    side_nodes = np.where(mask)[0]
    n_inds[side] = side_nodes[np.argsort(sort_coord[side_nodes])]

# dependent corner nodes: bottom-right, top-right, top-left
n_inds['dependent_corner_inds'] = np.array([
    np.where(on_bottom & on_right)[0][0],
    np.where(on_top & on_right)[0][0],
    np.where(on_top & on_left)[0][0],
])

# independent (fixed) corner node: bottom-left
n_inds['fixed_corner_ind'] = np.where(on_bottom & on_left)[0][0]

print(n_inds['fixed_corner_ind'])
print(n_inds['dependent_corner_inds'])

In [None]:
# Nodal displacements per load case: each flat U vector becomes a (2, n_nodes) array
# whose row 0 is added to the x-coordinates later (presumably U is stored as x/y pairs).
n_data['U_arr'] = np.array([np.reshape(u_flat, (-1, 2)).T for u_flat in df['U']])

# graph-level results, one row per load case
g_data['F'] = np.stack(df['F'])
g_data['W'] = df['W'].values[..., np.newaxis]
# P is stored flat in MATLAB's column-major order: reorder to row-major, then make it 2x2
p_flat = np.stack(df['P'].values, axis=0)
g_data['P'] = p_flat[:, [0, 2, 1, 3]].reshape(-1, 2, 2)
g_data['D'] = np.stack(df['D'].values, axis=0)
g_data['traj'] = df['trajectory'].values[..., np.newaxis]

In [None]:
# overview of the graph-level arrays
for name, arr in g_data.items():
    print(f'{name:10} {arr.shape}')

# Edges
Turn the triangles (quadratic elements) into edges

In [None]:
# True: build the edges from the triangle corner nodes only;
# False: use all six nodes of the quadratic elements
corner_nodes_only: bool = False


In [None]:
# Every triangle lists its 3 corner nodes first, followed by its 3 mid-side nodes.
# Walking around the element gives its edges as (start row, end row) pairs of the
# connectivity array.
if corner_nodes_only:
    # corner -> corner
    start_rows, end_rows = [0, 1, 2], [1, 2, 0]
else:
    # corner -> mid-side -> corner -> ...
    start_rows, end_rows = [0, 3, 1, 4, 2, 5], [3, 1, 4, 2, 5, 0]

tri = n_inds['t']
edge_index = np.stack((tri[start_rows], tri[end_rows])).reshape(2, -1)
print(edge_index.shape)

In [None]:
# orient every edge from lower to higher node index so that 0->1 and 1->0 coincide
edge_index = np.sort(edge_index, axis=0)

# drop duplicate edges (an interior edge is shared by two elements)
print(edge_index.shape)
edge_index, counts = np.unique(edge_index, axis=1, return_counts=True)
print(edge_index.shape)

# edges that occur in a single element lie on a boundary (outer sides or holes)
e_inds['all_boundaries'] = np.flatnonzero(counts == 1)

In [None]:
# The mid-side -> mid-side edges are added only now, so that they are not counted as
# boundary edges even though each of them belongs to a single element.
if not corner_nodes_only:
    tri = n_inds['t']
    mid_edges = np.stack((tri[[3, 4, 5]], tri[[4, 5, 3]])).reshape(2, -1)
    print(edge_index.shape)
    print(mid_edges.shape)
    edge_index = np.concatenate((edge_index, mid_edges), axis=-1)
    print(edge_index.shape)

e_data['edge_index'] = edge_index

In [None]:
## Plot mesh to check
%matplotlib qt

# plot all nodes in original location
if constant_mesh:
    pos_temp = n_data['pos']
    edge_index_temp = e_data['edge_index']
else:
    pos_temp = n_data['pos'][0]
    edge_index_temp = e_data['edge_index'][0]
plt.scatter(*pos_temp, c='black', s=1)  #, s=50, alpha=0.5)

# plot edges
x, y = pos_temp.T[edge_index_temp].T
x, y = x.T, y.T
plt.plot(x, y, c='black', alpha=0.3, zorder=-1)

# make plot pretty
plt.gca().set_aspect('equal')
plt.grid()


In [None]:
np.shape(np.unique(np.sort(e_data['edge_index'], axis=0), axis=1))  # distinct undirected edges

In [None]:
np.shape(e_data['edge_index'])

In [None]:
np.shape(n_data['pos'])

# Remove unused nodes

In [None]:
# Keep only the nodes that occur in at least one edge and relabel them 0..n_kept-1.
# (Only makes a difference when corner_nodes_only is True: the mid-side nodes are then removed.)
nodes_to_keep, edge_index = np.unique(e_data['edge_index'], return_inverse=True)
# np.unique flattens its input, so restore the (2, n_edges) layout of the relabelled edges
edge_index = edge_index.reshape(2, -1)
# store the relabelled edges: they must index the filtered node arrays below
e_data['edge_index'] = edge_index

for key in n_inds:
    # position of each indexed node within the sorted nodes_to_keep; the modulo keeps
    # indices beyond the largest kept node in range (those then fail the check below)
    new_inds = np.searchsorted(nodes_to_keep, n_inds[key]) % len(nodes_to_keep)

    # keep only the nodes that survive, expressed in the new labels
    n_inds[key] = new_inds[nodes_to_keep[new_inds] == n_inds[key]]

for key in n_data:
    n_data[key] = n_data[key][..., nodes_to_keep]

# Calculate positions (original, affine, final) and displacements

In [None]:
# node positions after applying the affine map F of every load case
n_data['pos_affine'] = g_data['F'] @ n_data['pos']

# undeformed edge vectors (end node minus start node) and their lengths
src, dst = e_data['edge_index']
e_data['r'] = n_data['pos'][..., dst] - n_data['pos'][..., src]
e_data['d'] = np.linalg.norm(e_data['r'], axis=-2)

# deformed node positions: reference positions plus FE displacements
n_data['pos_final'] = n_data['pos'] + n_data['U_arr']


In [None]:
## Plot meshes
%matplotlib qt

# plot all nodes in original location
plt.scatter(*(n_data['pos']), c='tab:blue', s=1)

# plot edges
x, y = n_data['pos'].T[e_data['edge_index']].T
x, y = x.T, y.T
plt.plot(x, y, c='tab:blue', alpha=0.3, zorder=-1)


# plot all nodes in affine location
plt.scatter(*(n_data['pos_affine'][0]), c='tab:orange', s=1)  #, s=50, alpha=0.5)

# plot edges
x, y = n_data['pos_affine'][0].T[e_data['edge_index']].T
x, y = x.T, y.T
plt.plot(x, y, c='tab:orange', alpha=0.3, zorder=-1)


# plot all nodes in final location
plt.scatter(*(n_data['pos_final'][0]), c='tab:green', s=1)  #, s=50, alpha=0.5)

# plot edges
x, y = n_data['pos_final'][0].T[e_data['edge_index']].T
x, y = x.T, y.T
plt.plot(x, y, c='tab:green', alpha=0.3, zorder=-1)



# make plot pretty
plt.gca().set_aspect('equal')
plt.grid()


In [None]:
# list every stored quantity with its shape, one group per dict
for container in (e_data, n_data, g_data, e_inds, n_inds):
    for name, value in container.items():
        label = name + '.shape'
        print(f'{label:30}\t', value.shape)
    print('')

In [None]:
%matplotlib qt

# plot all nodes in original location
# plt.scatter(*(n_data['pos']), c='black', s=1)  #, s=50, alpha=0.5)

# plot nodes on hole bound
plt.scatter(*(n_data['pos']), s=10)  #, marker='x', s=50, c=quad)

# plot edges
pos1 = n_data['pos'][..., e_data['edge_index'][0]]
pos2 = pos1 + e_data['r']
x = np.stack((pos1[0], pos2[0]))
y = np.stack((pos1[1], pos2[1]))
plt.plot(x, y, c='red', alpha=0.3, zorder=-1)

# # plot boundary edges
# b_edge_index = e_data['edge_index'][..., e_inds['hole_boundary']]
# pos1 = n_data['pos'][..., b_edge_index[0]]
# pos2 = pos1 + e_data['r'][..., e_inds['hole_boundary']]
# x = np.stack((pos1[0], pos2[0]))
# y = np.stack((pos1[1], pos2[1]))
# plt.plot(x, y, c='green', zorder=-1)

# make plot pretty
plt.gca().set_aspect('equal')
plt.grid()

# Find hole boundaries

In [None]:
# every node that belongs to at least one boundary edge (outer sides or holes)
n_inds['all_boundaries'] = np.unique(e_data['edge_index'][:, e_inds['all_boundaries']])

# every node on the four outer sides of the RVE, corners included
side_keys = ('b_bottom', 'b_top', 'b_left', 'b_right', 'dependent_corner_inds', 'fixed_corner_ind')
sides_boundary_nodes = np.concatenate([n_inds[k] for k in side_keys])
n_inds['sides'] = np.unique(sides_boundary_nodes)

In [None]:
# boundary nodes that are not on an outer side lie on a hole
n_inds['hole_boundary'] = np.setdiff1d(ar1=n_inds['all_boundaries'], ar2=n_inds['sides'])

In [None]:
# an edge belongs to a boundary group when both of its end nodes are in that group
both_on_sides = np.isin(e_data['edge_index'], n_inds['sides']).all(axis=0)
both_on_holes = np.isin(e_data['edge_index'], n_inds['hole_boundary']).all(axis=0)
e_inds['sides'] = np.flatnonzero(both_on_sides)
e_inds['hole_boundary'] = np.flatnonzero(both_on_holes)

In [None]:
## Plot mesh to check boundary nodes
%matplotlib qt

# plot all nodes in original location
plt.scatter(*(n_data['pos']), c='black', s=1)  #, s=50, alpha=0.5)

# plot edges
x, y = n_data['pos'].T[e_data['edge_index']].T
x, y = x.T, y.T
plt.plot(x, y, c='black', alpha=0.3, zorder=-1)

# plot nodes on hole boundary
plt.scatter(*n_data['pos'][:, n_inds['hole_boundary']], marker='x', s=50)

# plot boundary edges
x, y = n_data['pos'].T[e_data['edge_index']].T[:, e_inds['hole_boundary']]
x, y = x.T, y.T
plt.plot(x, y, c='red', alpha=1, zorder=-1)

# plot nodes on sides
plt.scatter(*n_data['pos'][:, n_inds['sides']], marker='x', s=50)

# plot boundary sides
x, y = n_data['pos'].T[e_data['edge_index']].T[:, e_inds['sides']]
x, y = x.T, y.T
plt.plot(x, y, c='green', alpha=1, zorder=-1)

# make plot pretty
plt.gca().set_aspect('equal')
plt.grid()


In [None]:
%matplotlib qt

# plot all nodes in original location
# plt.scatter(*(n_data['pos']), c='black', s=1)  #, s=50, alpha=0.5)

# plot nodes on hole bound
plt.scatter(*(n_data['pos']), s=10)  #, marker='x', s=50, c=quad)

# plot edges
pos1 = n_data['pos'][..., e_data['edge_index'][0]]
pos2 = pos1 + e_data['r']
x = np.stack((pos1[0], pos2[0]))
y = np.stack((pos1[1], pos2[1]))
plt.plot(x, y, c='red', alpha=0.3, zorder=-1)

# plot boundary edges
b_edge_index = e_data['edge_index'][..., e_inds['hole_boundary']]
pos1 = n_data['pos'][..., b_edge_index[0]]
pos2 = pos1 + e_data['r'][..., e_inds['hole_boundary']]
x = np.stack((pos1[0], pos2[0]))
y = np.stack((pos1[1], pos2[1]))
plt.plot(x, y, c='green', zorder=-1)

# make plot pretty
plt.gca().set_aspect('equal')
plt.grid()

# Remove configurations with overlapping elements

In [None]:
# Check every load case for self-intersection of the deformed hole boundaries: if any two
# hole-boundary edges cross (according to fh.intersect), the load case index is recorded
# in bad_arr and its deformed hole boundaries are plotted in a new figure.

bad_arr = []  # indices of load cases with overlap

# The hole-boundary edges are the same for every load case, so gather them once.
boundary_edges = e_data['edge_index'][..., e_inds['hole_boundary']].T  # (n_b_edges, 2)
n_b_edges = boundary_edges.shape[0]
edge_ids = np.arange(n_b_edges)

for i, pos_temp in enumerate(n_data['pos_final']):
    print(i, end=' ')
    pos_temp = pos_temp.T  # deformed node positions, (n_nodes, 2)

    # end points of every boundary edge: (n_b_edges, 2 end nodes, 2 coords)
    temp = pos_temp[boundary_edges]

    # all pairs of boundary edges: points[a, b] = [end points of edge b, end points of edge a],
    # built from broadcast views instead of two repeated copies -> (n_b, n_b, 4, 2)
    points1 = np.broadcast_to(temp[np.newaxis, ...], (n_b_edges, *temp.shape))
    points2 = np.broadcast_to(temp[:, np.newaxis, ...], (n_b_edges, *temp.shape))
    points = np.concatenate((points1, points2), axis=2)

    # check which pairs of edges intersect; an edge compared with itself does not count
    bools = fh.intersect(points)
    bools[edge_ids, edge_ids] = False

    if bools.any():
        print('bad graph!', g_data['F'][i])

        # plot all nodes and the hole-boundary edges of this load case
        plt.figure()
        plt.scatter(*(pos_temp.T), s=2)
        x, y = temp[..., 0].T, temp[..., 1].T
        plt.plot(x, y, c='tab:blue', alpha=1, zorder=-1)

        plt.gca().set_aspect('equal')
        plt.grid()

        bad_arr.append(i)



In [None]:
# save all the figures that were just created in one pdf

from matplotlib.backends.backend_pdf import PdfPages
import matplotlib.pyplot as plt


def multipage(filename, figs=None, dpi=200):
    """Save figures as the pages of a single PDF file.

    Parameters
    ----------
    filename : str or path-like
        Output PDF path.
    figs : iterable of matplotlib Figure, optional
        Figures to save, in page order. Defaults to all currently open figures.
    dpi : float, default 200
        Resolution passed to ``Figure.savefig`` (affects rasterised content).
    """
    if figs is None:
        figs = [plt.figure(n) for n in plt.get_fignums()]
    # the context manager finalises and closes the PDF even if saving a figure fails
    with PdfPages(filename) as pp:
        for fig in figs:
            fig.savefig(pp, format='pdf', dpi=dpi)


multipage('loadcases_with_overlap.pdf')
plt.close('all')

In [None]:
# remove all cases with overlap from the data

n_loadcases = len(g_data['F'])

bools = np.ones(n_loadcases, dtype=bool)
bools[bad_arr] = False

# Only g_data and n_data hold per-load-case arrays. e_data and the index dicts
# describe the shared mesh and must not be filtered, even if one of their arrays
# happens to have n_loadcases entries along its first axis.
for key in g_data:
    # every graph-level quantity has one row per load case
    g_data[key] = g_data[key][bools]
for key in n_data:
    # per-load-case node arrays are (n_loadcases, 2, n_nodes); the shared reference
    # positions are (2, n_nodes) and are left untouched
    if n_data[key].ndim == 3 and n_data[key].shape[0] == n_loadcases:
        n_data[key] = n_data[key][bools]

In [None]:
# shapes of all stored quantities after removing the overlapping load cases
for container in (e_data, n_data, g_data, e_inds, n_inds):
    for name in container:
        print(f'{name + ".shape":30}\t', container[name].shape)
    print('')

In [None]:
# Save the processed data next to the DataFrame, named 'data_' + the part of the
# DataFrame file name after its first underscore.
dir_results, dataframe_file = os.path.split(df_path)
name_suffix = dataframe_file.split('_', maxsplit=1)[1]
path = os.path.join(dir_results, 'data_' + name_suffix)
with open(path, 'wb') as out_file:
    pickle.dump([e_data, n_data, g_data, e_inds, n_inds], out_file)

# Merge boundary nodes

In [None]:
# Reload the processed data saved above ('data_' + DataFrame name after its first underscore).
dir_results, dataframe_file = os.path.split(df_path)
name_suffix = dataframe_file.split('_', maxsplit=1)[1]
path = os.path.join(dir_results, f'data_{name_suffix}')
with open(path, 'rb') as in_file:
    e_data, n_data, g_data, e_inds, n_inds = pickle.load(in_file)

In [None]:
# shapes of all reloaded quantities
for container in (e_data, n_data, g_data, e_inds, n_inds):
    for name, value in container.items():
        print(f'{name + ".shape":30}\t', value.shape)
    print('')

In [None]:
# Merge the top/right side nodes into their bottom/left counterparts, and the three
# dependent corners into the fixed corner. This assumes fh.replace(arr, old, new,
# inplace=True) substitutes the labels `old` by `new` inside arr — confirm in funcs_helpers.
merge_pairs = (
    (n_inds['b_top'], n_inds['b_bottom']),
    (n_inds['b_right'], n_inds['b_left']),
    (n_inds['dependent_corner_inds'], 3 * [n_inds['fixed_corner_ind'][0]]),
)
for old_labels, new_labels in merge_pairs:
    fh.replace(e_data['edge_index'], old_labels, new_labels, inplace=True)

In [None]:
# Remove the duplicate edges created by the merge, then drop nodes that no longer occur.

# orient every edge from lower to higher node index so that 0->1 and 1->0 coincide
temptemp = np.sort(e_data['edge_index'], axis=0)

# first occurrence of every distinct edge
_, inds = np.unique(temptemp, axis=1, return_index=True)
print(f'{e_data["edge_index"].shape[-1]} edges, {len(inds)} unique')

# indices of edges to keep, in their original order
inds = np.sort(inds)

for key in e_data:
    e_data[key] = e_data[key][..., inds]
for key in e_inds:
    # position of each indexed edge within the kept edges; the modulo keeps indices past
    # the last kept edge in range (those then fail the check below)
    new_inds = np.searchsorted(inds, e_inds[key]) % len(inds)

    # keep only the edges that survived, expressed in the new edge numbering
    e_inds[key] = new_inds[inds[new_inds] == e_inds[key]]

# remove unused nodes and relabel the rest 0..n_kept-1
nodes_to_keep, e_data['edge_index'] = np.unique(e_data['edge_index'], return_inverse=True)
e_data['edge_index'] = e_data['edge_index'].reshape(2, -1)

for key in n_inds:
    # position of each indexed node within nodes_to_keep (modulo: see above)
    new_inds = np.searchsorted(nodes_to_keep, n_inds[key]) % len(nodes_to_keep)

    # keep only the nodes that survived, expressed in the new labels
    n_inds[key] = new_inds[nodes_to_keep[new_inds] == n_inds[key]]

for key in n_data:
    n_data[key] = n_data[key][..., nodes_to_keep]

In [None]:
## Plot mesh to check boundary nodes
%matplotlib qt

# plot all nodes in original location
plt.scatter(*(n_data['pos']), c='black', s=1)  #, s=50, alpha=0.5)

# plot edges
x, y = n_data['pos'].T[e_data['edge_index']].T
x, y = x.T, y.T
plt.plot(x, y, c='black', alpha=0.3, zorder=-1)

# plot nodes on hole boundary
plt.scatter(*n_data['pos'][:, n_inds['hole_boundary']], marker='x', s=50)

# plot boundary edges
x, y = n_data['pos'].T[e_data['edge_index']].T[:, e_inds['hole_boundary']]
x, y = x.T, y.T
plt.plot(x, y, c='red', alpha=1, zorder=-1)

# plot nodes on sides
plt.scatter(*n_data['pos'][:, n_inds['sides']], marker='x', s=50)

# plot boundary sides
x, y = n_data['pos'].T[e_data['edge_index']].T[:, e_inds['sides']]
x, y = x.T, y.T
plt.plot(x, y, c='green', alpha=1, zorder=-1)

# make plot pretty
plt.gca().set_aspect('equal')
plt.grid()


In [None]:
np.shape(n_data['pos'])

In [None]:
np.shape(e_data['edge_index'])

In [None]:
%matplotlib qt

# plot all nodes in original location
# plt.scatter(*(n_data['pos']), c='black', s=1)  #, s=50, alpha=0.5)

# plot nodes on hole bound
plt.scatter(*(n_data['pos']), s=10)  #, marker='x', s=50, c=quad)

# plot edges
pos1 = n_data['pos'][..., e_data['edge_index'][0]]
pos2 = pos1 + e_data['r']
x = np.stack((pos1[0], pos2[0]))
y = np.stack((pos1[1], pos2[1]))
plt.plot(x, y, c='red', alpha=0.3, zorder=-1)

# plot boundary edges
b_edge_index = e_data['edge_index'][..., e_inds['hole_boundary']]
pos1 = n_data['pos'][..., b_edge_index[0]]
pos2 = pos1 + e_data['r'][..., e_inds['hole_boundary']]
x = np.stack((pos1[0], pos2[0]))
y = np.stack((pos1[1], pos2[1]))
plt.plot(x, y, c='green', zorder=-1)

# plot copy
# plot nodes on hole bound
pos_temp = n_data['pos'] + [[3.2], [0]]
plt.scatter(*(pos_temp), s=10)  #, marker='x', s=50, c=quad)

# plot edges
pos1 = pos_temp[..., e_data['edge_index'][0]]
pos2 = pos1 + e_data['r']
x = np.stack((pos1[0], pos2[0]))
y = np.stack((pos1[1], pos2[1]))
plt.plot(x, y, c='red', alpha=0.3, zorder=-1)

# make plot pretty
plt.gca().set_aspect('equal')
plt.grid()

# Remove bulk nodes

In [None]:
if remove_bulk_nodes:
    # keep only the hole-boundary nodes; they are relabelled 0..n_kept-1
    nodes_to_keep = n_inds['hole_boundary'].copy()
    n_kept = len(nodes_to_keep)

    for key in n_inds:
        # position within the (sorted) kept nodes; the modulo keeps indices in range
        new_inds = np.searchsorted(nodes_to_keep, n_inds[key]) % n_kept
        # keep the surviving nodes, expressed in the new labels
        n_inds[key] = new_inds[nodes_to_keep[new_inds] == n_inds[key]]

    for key in n_data:
        n_data[key] = n_data[key][..., nodes_to_keep]

    # keep only the hole-boundary edges, which now come first and in order
    for key in e_data:
        e_data[key] = e_data[key][..., e_inds['hole_boundary']]
    e_inds['hole_boundary'] = np.arange(e_data['edge_index'].shape[-1])
    # NOTE(review): the other e_inds entries still refer to the old edge numbering here

    # relabel the edge end points to the new node labels
    fh.replace(e_data['edge_index'], nodes_to_keep, np.arange(n_kept), inplace=True)

In [None]:
if remove_bulk_nodes:
    pos_temp = n_data['pos'].T  # (n_nodes, 2)
    n_nodes = n_data['pos'].shape[1]

    # quadrant of every node: +1 if x > 0 (right half), +2 if y > 0 (upper half),
    # i.e. 0 = lower left, 1 = lower right, 2 = upper left, 3 = upper right
    right_half = (pos_temp[:, 0] > 0).astype(int)
    upper_half = (pos_temp[:, 1] > 0).astype(int)
    n_data['quad'] = right_half + 2 * upper_half

# Tile RVEs - old approach (the tiled copies do not reproduce exactly the same mesh)

In [None]:
# # Tile RVEs
# tiling = [2, 2]
# basis_vecs = np.array([[3.2, 0], [0, 3.2]]).reshape(2, 2, 1)
# n_loadcases = len(n_data['pos_final'])

# n_data_new = {'pos': np.array([]).reshape(2, 0), 'quad': np.array([], dtype=int), 'pos_final': np.array([]).reshape(n_loadcases, 2, 0)}
# for i in range(tiling[0]):
#     for j in range(tiling[1]):
#         pos2 = n_data['pos'] + i*basis_vecs[0] + j*basis_vecs[1]
#         quad2 = n_data['quad'] + 4*(tiling[1]*i + j)
#         n_data_new['pos'] = np.concatenate((n_data_new['pos'], pos2), axis=-1)
#         n_data_new['quad'] = np.concatenate((n_data_new['quad'], quad2))

#         basis_vecs2 = np.einsum('ijk,lkm->ijkm', g_data['F'], basis_vecs)
#         pos_final2 = n_data['pos_final'] + i*basis_vecs2[:, 0] + j*basis_vecs2[:, 1]
#         n_data_new['pos_final'] = np.concatenate((n_data_new['pos_final'],
#                                                   pos_final2), axis=-1)

In [None]:
# n_RVEs = np.prod(tiling)
# n_edges = e_data['edge_index'].shape[-1]

# # all RVE copies get the same edges
# e_data['edge_index'] = np.repeat(e_data['edge_index'].reshape(2, n_edges, 1), n_RVEs, axis=-1)
# e_data['r'] = np.repeat(e_data['r'].reshape(2, n_edges, 1), n_RVEs, axis=-1).reshape(2, -1)
# e_data['d'] = np.repeat(e_data['d'].reshape(n_edges, 1), n_RVEs, axis=-1).flatten()
# e_inds['hole_boundary'] = np.repeat(e_data['edge_index'].reshape(-1, 1), n_RVEs, axis=-1)

# # but nodes they connect must be incremented
# e_data['edge_index'] += np.arange(n_RVEs)*n_edges
# e_data['edge_index'] = e_data['edge_index'].reshape(2, -1)
# e_inds['hole_boundary'] += np.arange(n_RVEs)*n_edges
# e_inds['hole_boundary'] = e_inds['hole_boundary'].flatten()


In [None]:
# n_data = n_data_new

# Create new edges

In [None]:
if remove_bulk_nodes:
    # Strategy for connecting the hole-boundary nodes with new edges:
    #   'all'      - connect every pair of distinct nodes
    #   'connect3' - connect every node to its nearest node on each of the (up to) three
    #                nearest other holes (holes are identified by n_data['quad'])
    #   'connect1' - not implemented
    mode = 'connect3'

    # Node-to-node vector components beyond +-half_period are shifted by 2*half_period
    # (periodic wrap). 1.6 is presumably half of the RVE side length 3.2 used elsewhere
    # in this notebook — TODO confirm.
    half_period = 1.6

    n_quads = n_data['quad'].max() + 1
    print(n_quads)
    edges_new = []
    pos_temp = n_data['pos'].T  # (n_nodes, 2)

    if mode == 'all':
        n_nodes = pos_temp.shape[0]
        x, y = np.meshgrid(np.arange(n_nodes), np.arange(n_nodes))
        edges_new = np.stack((x.flatten(), y.flatten()), axis=1)

        # remove self loops
        edges_new = edges_new[edges_new[:, 0] != edges_new[:, 1]]

    if mode == 'connect3':
        for q in range(n_quads):
            # nodes around hole in quadrant q (q-hole)
            hole_nodes = np.where(n_data['quad'] == q)[0]

            temp = []     # per other hole: (n nodes on hole q, 2) candidate edges
            d_temp2 = []  # per other hole: (n nodes on hole q,) candidate edge lengths
            # for each node in q-hole, find the nearest node in each other hole
            for q2 in range(n_quads):
                if q == q2:
                    continue

                other_hole_nodes = np.where(n_data['quad'] == q2)[0]

                # periodic vectors between nodes around hole q and nodes around hole q2
                r_temp = pos_temp[hole_nodes].reshape(-1, 1, 2) - pos_temp[other_hole_nodes].reshape(1, -1, 2)
                r_temp[r_temp > half_period] = -half_period*2 + r_temp[r_temp > half_period]
                r_temp[r_temp < -half_period] = half_period*2 + r_temp[r_temp < -half_period]
                d_temp = np.linalg.norm(r_temp, axis=-1)
                # d_temp.shape = [nr of nodes on hole q, nr of nodes on hole q2]

                inds = np.argmin(d_temp, axis=1)
                closest_nodes = other_hole_nodes[inds]

                temp.append(np.stack((hole_nodes, closest_nodes), axis=-1))
                d_temp2.append(d_temp[np.arange(len(inds)), inds])

            if not temp:
                continue  # there is no other hole to connect to

            # for every node, rank the other holes by distance and keep the 3 closest
            inds = np.argsort(d_temp2, axis=0)
            temp = np.array(temp)
            edges_new.extend(np.take_along_axis(temp, inds[:3][..., np.newaxis], axis=0))

        # holes may carry different numbers of nodes, so concatenate the per-hole
        # (n_nodes_on_hole, 2) blocks instead of building one (possibly ragged) array
        if edges_new:
            edges_new = np.concatenate(edges_new, axis=0).reshape(-1, 2)
        else:
            edges_new = np.empty((0, 2), dtype=int)

    if mode == 'connect1':
        # each node would get a single edge to the closest node of any other hole
        raise NotImplementedError('connect1 not yet implemented')

In [None]:
if remove_bulk_nodes:
    edges_new = np.array(edges_new)
    print('edges before deduplicating:', edges_new.shape)
    # undirected deduplication: write every edge as (low, high), then keep distinct rows
    edges_new, counts = np.unique(np.sort(edges_new, axis=1), axis=0, return_counts=True)
    print('edges after deduplicating :', edges_new.shape)
    # switch to the (2, n_new_edges) layout used by edge_index
    edges_new = edges_new.T

In [None]:
if remove_bulk_nodes:
    # new set of edges: original edges at the hole boundary + newly created edges
    # redefine e_data (all old data can be thrown away, needs to be recalculated)
    # e_inds and e_data are replaced wholesale, so any other previous entries are discarded.
    if mode == 'all':
        # deduplicate again
        # NOTE(review): edges_new columns were sorted as (low, high) in the previous cell,
        # but the hole-boundary edges keep their original orientation, so a boundary edge
        # stored as (high, low) is not recognised as a duplicate here — confirm whether it matters.
        n_boundary = e_data['edge_index'].shape[-1]
        edges_new = np.concatenate((e_data['edge_index'], edges_new), axis=1)
        edges_new, inv = np.unique(edges_new, axis=1, return_inverse=True)

        # new positions of the original hole-boundary edges after unique() reordered the columns
        e_inds = {'hole_boundary': inv[:n_boundary]}
        e_data = {'edge_index': edges_new}
    else:
        # the hole-boundary edges are placed first, so their indices are 0..n_boundary-1
        e_inds = {'hole_boundary': np.arange(e_data['edge_index'].shape[-1])}
        e_data = {'edge_index': np.concatenate((e_data['edge_index'], edges_new), axis=1)}

    # edge vectors (end node minus start node); components beyond +-1.6 are shifted by
    # 3.2 (periodic wrap; 1.6 is presumably half the RVE side length — TODO confirm)
    e_data['r'] = pos_temp[e_data['edge_index'][1]] - pos_temp[e_data['edge_index'][0]]
    e_data['r'][e_data['r'] > 1.6] = -1.6*2 + e_data['r'][e_data['r'] > 1.6]
    e_data['r'][e_data['r'] < -1.6] = 1.6*2 + e_data['r'][e_data['r'] < -1.6]
    # (n_edges, 2) -> (2, n_edges), then the edge lengths
    e_data['r'] = e_data['r'].T
    e_data['d'] = np.linalg.norm(e_data['r'], axis=-2)


In [None]:
import matplotlib

In [None]:
%matplotlib qt

# larger default font for the figures below
font = {'size': 15}
matplotlib.rc('font', **font)

# hole-boundary nodes at their undeformed positions
plt.scatter(*n_data['pos'], s=10)

# every edge, drawn from its source node along its (wrapped) edge vector r
src = n_data['pos'][..., e_data['edge_index'][0]]
dst = src + e_data['r']
plt.plot(np.stack((src[0], dst[0])), np.stack((src[1], dst[1])),
         c='red', alpha=0.3, zorder=-1)

# hole-boundary edges highlighted in green
boundary = e_inds['hole_boundary']
src = n_data['pos'][..., e_data['edge_index'][0, boundary]]
dst = src + e_data['r'][..., boundary]
plt.plot(np.stack((src[0], dst[0])), np.stack((src[1], dst[1])),
         c='green', zorder=-1)

# identical tick spacing on both axes
stepsize = 1
ax = plt.gca()
for lims, axis in ((ax.get_xlim(), ax.xaxis), (ax.get_ylim(), ax.yaxis)):
    lo, hi = (np.ceil(v/stepsize)*stepsize for v in lims)
    axis.set_ticks(np.arange(lo, hi, stepsize))

ax.set_aspect('equal')
plt.grid()

In [None]:
# Shape of every array in the edge/node/graph/index dicts, one block per dict.
for container in (e_data, n_data, g_data, e_inds, n_inds):
    for key, arr in container.items():
        print(f'{key + ".shape":30}\t', arr.shape)
    print('')

In [None]:
# Pickle the five data dicts next to the DataFrame as 'data2_<rest>', where
# <rest> is the DataFrame file name after its first underscore.
dir_results, dataframe_file = os.path.split(df_path)
name_tail = dataframe_file.split('_', maxsplit=1)[1]
path = os.path.join(dir_results, f'data2_{name_tail}')
with open(path, 'wb') as out_file:
    pickle.dump([e_data, n_data, g_data, e_inds, n_inds], out_file)

# Tile RVEs

In [None]:
# Reload the five data dicts from 'data2_<rest>' next to the DataFrame.
dir_results, dataframe_file = os.path.split(df_path)
name_tail = dataframe_file.split('_', maxsplit=1)[1]
path = os.path.join(dir_results, f'data2_{name_tail}')
with open(path, 'rb') as in_file:
    e_data, n_data, g_data, e_inds, n_inds = pickle.load(in_file)

In [None]:
# Tile RVEs
tiling = None  # [2, 2]


def tile_rve_graph(n_data, e_data, e_inds, F, tiling, period=3.2):
    """Repeat a periodic RVE graph on a ``tiling[0] x tiling[1]`` grid of copies.

    Copy ``(i, j)`` is shifted by ``i`` lattice vectors along x and ``j`` along y,
    and its nodes/edges are offset by ``nr_nodes*(tiling[1]*i + j)`` /
    ``nr_edges*(tiling[1]*i + j)``. Edges that wrapped around the single RVE
    (a component of their direct offset exceeds ``period/2``) are redirected to
    the matching node in the neighbouring copy, so the tiled graph is again
    periodic, now with period ``tiling*period``. Works for any tiling.

    Args:
        n_data: dict with 'pos' (2, nr_nodes), 'quad' (nr_nodes,) and
            'pos_final' (nr_loadcases, 2, nr_nodes).
        e_data: dict with 'edge_index' (2, nr_edges), 'r' (2, nr_edges) and
            'd' (nr_edges,).
        e_inds: dict with 'hole_boundary' edge indices.
        F: (nr_loadcases, 2, 2) deformation gradients. The deformed lattice
            vectors are taken as ``F @ b`` (F acting on column vectors). The
            previous einsum summed over the lattice-vector index and, for the
            diagonal basis, effectively used the rows of F.
            NOTE(review): confirm F is stored in the x = F @ X convention.
        tiling: number of copies along x and y.
        period: side length of the square RVE.

    Returns:
        (n_data, e_data, e_inds) of the tiled graph, with only the keys above.
    """
    nr_nodes = n_data['pos'].shape[-1]
    nr_edges = e_data['edge_index'].shape[-1]
    basis_vecs = period*np.eye(2)  # row l is lattice vector l
    # deformed lattice vectors: basis_final[c, l] = F[c] @ basis_vecs[l]
    basis_final = np.einsum('cjk,lk->clj', F, basis_vecs)

    pos, quad, pos_final, edge_index, r, d, boundary = [], [], [], [], [], [], []
    for i in range(tiling[0]):
        for j in range(tiling[1]):
            t = tiling[1]*i + j
            shift = i*basis_vecs[0] + j*basis_vecs[1]
            shift_final = i*basis_final[:, 0] + j*basis_final[:, 1]
            pos.append(n_data['pos'] + shift[:, np.newaxis])
            quad.append(n_data['quad'] + 4*t)  # 4 quadrant labels per RVE
            pos_final.append(n_data['pos_final'] + shift_final[..., np.newaxis])
            edge_index.append(e_data['edge_index'] + nr_nodes*t)
            r.append(e_data['r'])
            d.append(e_data['d'])
            boundary.append(e_inds['hole_boundary'] + nr_edges*t)

    pos = np.concatenate(pos, axis=-1)
    edge_index = np.concatenate(edge_index, axis=1)

    # Redirect wraparound edges. Source and target start in the same copy, but
    # the target lies on the far side of the cell; its adjacent periodic image
    # sits one copy further, against the direction of the direct offset.
    r_real = pos[:, edge_index[1]] - pos[:, edge_index[0]]
    step = np.where(np.abs(r_real) > period/2, np.sign(r_real), 0).astype(int)
    tile, local = np.divmod(edge_index[1], nr_nodes)
    ti, tj = np.divmod(tile, tiling[1])
    ti = (ti - step[0]) % tiling[0]
    tj = (tj - step[1]) % tiling[1]
    edge_index[1] = (tiling[1]*ti + tj)*nr_nodes + local

    n_data_new = {'pos': pos,
                  'quad': np.concatenate(quad),
                  'pos_final': np.concatenate(pos_final, axis=-1)}
    e_data_new = {'edge_index': edge_index,
                  'r': np.concatenate(r, axis=1),
                  'd': np.concatenate(d, axis=0)}
    e_inds_new = {'hole_boundary': np.concatenate(boundary, axis=0)}
    return n_data_new, e_data_new, e_inds_new


if tiling is not None:
    n_data, e_data, e_inds = tile_rve_graph(n_data, e_data, e_inds, g_data['F'], tiling)


In [None]:
# Shapes of all stored arrays, with a banner naming each dict.
containers = {'e_data': e_data, 'n_data': n_data, 'g_data': g_data,
              'e_inds': e_inds, 'n_inds': n_inds}
for label, container in containers.items():
    print(f'=============== {label} ===============')
    for key, arr in container.items():
        print(f'{key+".shape":30}\t', arr.shape)
    print('')

## Plot all edges, drawn along their edge vectors r

In [None]:
%matplotlib qt

# nodes at their undeformed positions
plt.scatter(*n_data['pos'], s=10)

# every edge, drawn from its source node along its edge vector r
edge_src = n_data['pos'][..., e_data['edge_index'][0]]
edge_dst = edge_src + e_data['r']
plt.plot(np.stack((edge_src[0], edge_dst[0])), np.stack((edge_src[1], edge_dst[1])),
         c='red', alpha=0.3, zorder=-1)

# hole-boundary edges in green
hb = e_inds['hole_boundary']
edge_src = n_data['pos'][..., e_data['edge_index'][0, hb]]
edge_dst = edge_src + e_data['r'][..., hb]
plt.plot(np.stack((edge_src[0], edge_dst[0])), np.stack((edge_src[1], edge_dst[1])),
         c='green', zorder=-1)

plt.gca().set_aspect('equal')
plt.grid()

## Plot all edges between their end-node positions

In [None]:
## Plot meshes
%matplotlib qt
# Nodes at their undeformed positions, coloured by quadrant when n_data has
# 'quad'. Without it, c is None, which is the same call as omitting c.
plt.scatter(*n_data['pos'], s=1, c=n_data.get('quad'), cmap='tab20')

# pos[:, edge_index] is (coordinate, edge end, edge); unpack x and y rows
x, y = n_data['pos'][:, e_data['edge_index']]
print(x.shape)
# skip wraparound edges, which would be drawn straight across the cell
bools = (np.abs(x[0] - x[1]) < 1.6)*(np.abs(y[0] - y[1]) < 1.6)
plt.plot(x[:, bools], y[:, bools], c='tab:blue', alpha=0.3, zorder=-1)

plt.gca().set_aspect('equal')
plt.grid()


In [None]:
np.shape(n_data['pos_final'][0])  # shape of the first load case's final positions

In [None]:
## Plot meshes
%matplotlib qt

# first load case, nodes at their final positions
final = n_data['pos_final'][0]
plt.scatter(*final, s=1, c='red')

# all edges (blue), then hole-boundary edges (green), at final positions;
# wraparound edges (|dx| or |dy| >= 1.6) are left out
edge_sets = ((e_data['edge_index'], 'tab:blue'),
             (e_data['edge_index'][..., e_inds['hole_boundary']], 'tab:green'))
for ei, colour in edge_sets:
    x, y = final[:, ei]
    bools = (np.abs(x[0] - x[1]) < 1.6)*(np.abs(y[0] - y[1]) < 1.6)
    plt.plot(x[:, bools], y[:, bools], c=colour, alpha=0.3, zorder=-1)

plt.gca().set_aspect('equal')
plt.grid()


# Duplicate edges to make the graph undirected

In [None]:
# Append every edge reversed. The reverse of edge k lands at index
# k + n_edges, so each stored edge-index set is extended by the same offset.
n_edges = e_data['edge_index'].shape[-1]
for key, inds in e_inds.items():
    e_inds[key] = np.concatenate((inds, inds + n_edges), axis=-1)

reversed_edges = e_data['edge_index'][::-1]
e_data['edge_index'] = np.concatenate((e_data['edge_index'], reversed_edges), axis=-1)

# a reversed edge points the other way but has the same length
e_data['r'] = np.concatenate((e_data['r'], -e_data['r']), axis=-1)
e_data['d'] = np.concatenate((e_data['d'],) * 2, axis=-1)

# Turn data into graphs


In [None]:
# Shape report after making the graph undirected.
for container in [e_data, n_data, g_data, e_inds, n_inds]:
    for line in [f'{key+".shape":30}\t {value.shape}' for key, value in container.items()]:
        print(line)
    print('')

In [None]:
# # boundary info in edges:
x = np.array([]).reshape(n_data['pos'].shape[-1], -1)
x = torch.tensor(x, dtype=torch.float)

edge_attr = torch.ones((e_data['edge_index'].shape[1], 1), dtype=torch.float)
edge_attr[e_inds['hole_boundary']] = -1

# # boundary info in nodes:
# x = torch.ones(len(nodes_to_keep), 1)
# x[hole_boundary_nodes2] = -1
# edge_attr = torch.tensor([]).reshape(edges2.shape[1], 0) # edge attributes should be empty for now


In [None]:
# --- mesh tensors, identical for every graph ---
edge_index = torch.tensor(e_data['edge_index'], dtype=torch.long)
d = torch.tensor(e_data['d'], dtype=torch.float).unsqueeze(-1)  # trailing feature axis
r = torch.tensor(e_data['r'], dtype=torch.float)
pos = torch.tensor(n_data['pos'], dtype=torch.float)

# --- per-graph node targets: swap the last two axes of pos_final ---
pos_final = torch.tensor(n_data['pos_final'].transpose(0, 2, 1), dtype=torch.float)

# --- graph-level attributes ---
g_data2 = {key: torch.tensor(g_data[key], dtype=torch.float) for key in ('F', 'W', 'P', 'D')}
g_data2['traj'] = torch.tensor(g_data['traj'], dtype=torch.long).reshape(-1)
cell_volume = (2*1.6)**2
g_data2['W'] = g_data2['W'].reshape(-1)/cell_volume  # divide by volume to get density
g_data2['mean_pos'] = pos_final.mean(dim=1)

In [None]:
# Shapes of the graph-level tensors.
for key, value in g_data2.items():
    print(f'{key+".shape":30}\t {value.shape}')
print('')

In [None]:
if not constant_mesh:
    raise NotImplementedError('variable mesh not implemented yet')

# One graph per load case. edge_index, edge_attr and x are shared between all
# graphs; the other mesh fields and the graph-level values are cloned per graph.
# [[i]] keeps a leading axis of length 1 on each graph-level value.
data_list0 = [
    tg.data.Data(
        edge_index=edge_index,
        edge_attr=edge_attr,
        x=x,
        y=target.clone(),
        pos=pos.T.clone(),
        r=r.T.clone(),
        d=d.clone(),
        **{key: value[[i]].clone() for key, value in g_data2.items()},
    )
    for i, target in enumerate(pos_final)
]

first_graph = data_list0[0]
print('data_list0[0]:', first_graph)
print('data_list0[0] undirected:', end=' ')
print(tg.utils.is_undirected(first_graph.edge_index, first_graph.edge_attr))

In [None]:
# Per-field memory report for the first graph: sys.getsizeof of the tensor's
# storage, and that figure divided by the element count (non-empty tensors only).
print(f'{"variable":12} {"type":23} {"size (bytes)":>13} {"shape":24} {"datatype":15} {"bytes per element"}')
for name, value in data_list0[0]:
    nbytes = sys.getsizeof(value.storage())
    row = f'{name:12} {str(type(value)):23} {nbytes:>13} {str(value.shape):24} {str(value.dtype):15}'
    if value.numel() > 0:
        row += f' {nbytes/value.numel()}'
    print(row)

# Save data

In [None]:
# Serialize the whole list of graphs and write it to gr_path in one call.
with open(gr_path, 'wb') as out_file:
    out_file.write(pickle.dumps(data_list0))

In [None]:
print(f'Nr of nodes:\t {len(x)}')
print(f'Nr of edges:\t {len(edge_attr)}')

# Calculate the graph diameter: the shortest-path length between the two most distant nodes

In [None]:
import networkx as nx

# Undirected networkx graph over the edge list: each column of edge_index is
# one (source, target) pair, and reversed duplicates collapse into one edge.
G = nx.Graph()
G.add_edges_from(tuple(pair) for pair in e_data['edge_index'].T)

In [None]:
next(iter(nx.all_pairs_shortest_path_length(G)))  # first (node, {target: distance}) pair

In [None]:
# Largest shortest-path distance from each node (its eccentricity within its
# component), then the maximum over all nodes.
maxes_temp = [max(dists.values()) for _, dists in nx.all_pairs_shortest_path_length(G)]
longest_path_length = max(maxes_temp)
print(longest_path_length)

# Data augmentation

In [None]:
import pickle
import sys

import torch
import numpy as np
import matplotlib.pyplot as plt

In [None]:
from scipy.stats.qmc import LatinHypercube

In [None]:
# NOTE(review): this loads the '_noBulkNodes_2' graph set, while gr_path at the
# top of the notebook (used by "Save data") ends in '_noBulkNodes_4'. Confirm
# which set is meant to be augmented.
gr_path = 'data\\coarseMesh_noBifurcation_5\\graphs_coarseMesh_noBifurcation_diameter_0.9_5_noBulkNodes_2.pkl'
with open(gr_path, 'rb') as graph_file:
    data_list0 = pickle.load(graph_file)


In [None]:
augmentation_factor = 1  # augmented copies per original graph

# One 4-D Latin-hypercube point per augmented graph; its coordinates drive
# (reflection x, reflection y, rotation, scaling) in the loop below.
n_augmented = len(data_list0)*augmentation_factor
sampler = LatinHypercube(d=4, seed=43)
sample = sampler.random(n=n_augmented)

In [None]:
data_list1 = []

# Rows of `sample` are consumed in order, augmentation_factor rows per graph.
sample_rows = iter(sample)
for graph in data_list0:
    for _ in range(augmentation_factor):
        u_refl_x, u_refl_y, u_rot, u_scale = next(sample_rows)

        reflection1 = np.round(u_refl_x)*2 - 1  # -1 or 1
        reflection2 = np.round(u_refl_y)*2 - 1  # -1 or 1
        phi = u_rot*2*np.pi                      # rotation angle
        scale_factor = 10**(2*u_scale - 1)       # between 0.1 and 10

        A = np.diag([1.0*reflection1, 1.0*reflection2])
        S = scale_factor*np.eye(2)
        R = np.array([[np.cos(phi), -np.sin(phi)],
                      [np.sin(phi), np.cos(phi)]])

        T = torch.tensor(R @ S @ A, dtype=torch.float)  # for vectors: reflect, scale, rotate
        T2 = torch.tensor(R @ A, dtype=torch.float)     # for tensors: reflect, rotate

        graph2 = graph.clone()

        # vector fields are stored one vector per row -> transform the transpose
        for attr in ('y', 'pos', 'r', 'mean_pos'):
            setattr(graph2, attr, torch.matmul(T, getattr(graph2, attr).T).T)

        # lengths only feel the scaling
        graph2.d = scale_factor*graph2.d

        # 2nd- and 4th-order tensors are conjugated with the rotation/reflection
        graph2.P = torch.einsum('lj,ijk,km->ilm', T2, graph2.P, T2.T)
        graph2.D = torch.einsum('nj,ok,pl,qm,ijklm->inopq', T2, T2, T2, T2, graph2.D)
        graph2.F = torch.einsum('ij,lk,mjk->mil', T2, T2, graph2.F)

        # record the applied transform on the graph
        graph2.reflection = torch.tensor([reflection1, reflection2])
        graph2.phi = torch.tensor([phi])
        graph2.scale_factor = torch.tensor([scale_factor])

        data_list1.append(graph2)

In [None]:
# Memory report for the first original graph, plus the total storage size
# (sys.getsizeof of each field's storage) extrapolated to the whole list.
print(f'{"variable":12} {"type":23} {"size (bytes)":>13} {"shape":24} {"datatype":15} {"bytes per element"}')

nr_of_bytes = 0
for name, value in data_list0[0]:
    storage_bytes = sys.getsizeof(value.storage())
    line = f'{name:12} {str(type(value)):23} {storage_bytes:>13} {str(value.shape):24} {str(value.dtype):15}'
    if value.numel() > 0:
        line += f' {storage_bytes/value.numel()}'
    print(line)
    nr_of_bytes += storage_bytes
print(nr_of_bytes)
print(nr_of_bytes*len(data_list0))

In [None]:
# Same memory report for the first augmented graph and the augmented list.
print(f'{"variable":12} {"type":23} {"size (bytes)":>13} {"shape":24} {"datatype":15} {"bytes per element"}')

nr_of_bytes = 0
for name, value in data_list1[0]:
    storage_bytes = sys.getsizeof(value.storage())
    per_element = f' {storage_bytes/value.numel()}' if value.numel() > 0 else ''
    print(f'{name:12} {str(type(value)):23} {storage_bytes:>13} {str(value.shape):24} {str(value.dtype):15}' + per_element)
    nr_of_bytes += storage_bytes
print(nr_of_bytes)
print(nr_of_bytes*len(data_list1))

In [None]:
%matplotlib qt

# inspect one augmented graph
graph = data_list1[126].clone()
for label in ('reflection', 'scale_factor', 'phi'):
    print(label, getattr(graph, label))

fig, ax = plt.subplots()

# Wraparound edges are detected from the (augmented) initial positions: an edge
# is drawn only if |dx| and |dy| are both below 1.6*scale_factor.
# NOTE(review): after a rotation this axis-aligned test depends on the angle;
# an edge-length test would be rotation invariant. Confirm which is intended.
graph_indices = graph.edge_index.detach().numpy()
pos_init = graph.pos.clone().detach().numpy().reshape(-1, 2)
x, y = np.transpose(pos_init[graph_indices], axes=[2, 0, 1])
half_cell = 1.6*graph.scale_factor[0].item()
bools = ((np.abs(np.diff(x, axis=0)) < half_cell)
         & (np.abs(np.diff(y, axis=0)) < half_cell)).flatten()

# single-graph batch vector (needed if a model prediction is overlaid)
graph.batch = torch.zeros(len(graph.x), dtype=torch.long)
pos_FEM = graph.y.clone().detach().numpy()  # target positions

# final nodes and edges
ax.scatter(*pos_FEM.T, label='final position', s=1)
x, y = np.transpose(pos_FEM[graph_indices], axes=[2, 0, 1])
edges_FEM = ax.plot(x[:, bools], y[:, bools], alpha=0.3, c='tab:green')

# corners of the untransformed cell, for reference
corner_coords = np.array([[-1.6, -1.6, 1.6, 1.6], [-1.6, 1.6, -1.6, 1.6]])
orig_corners = ax.scatter(*corner_coords, marker='x', label='original corners', c='red', s=20)

ax.set_aspect('equal')
ax.grid()
ax.axis('off')
ax.margins(0)


In [None]:
# Output file name records the augmentation factor.
gr_path2 = ('data\\coarseMesh_noBifurcation_5\\graphs_coarseMesh_noBifurcation_diameter_0.9_5_noBulkNodes_2_augmented_×'
            f'{augmentation_factor}.pkl')
with open(gr_path2, 'wb') as out_file:
    pickle.dump(data_list1, out_file)