In [1]:
# This code needs Ludo's environment (called environment_networkx-pyvista.yml) This environment does not have karstnet installed

import pyvista as pv

from pyvista import examples

mesh = examples.load_uniform()

pl = pv.Plotter(shape=(1, 2))

_ = pl.add_mesh(

    mesh, scalars='Spatial Point Data', show_edges=True

)

pl.subplot(0, 1)

_ = pl.add_mesh(

    mesh, scalars='Spatial Cell Data', show_edges=True

)

pl.export_html('pv.html')  

  "class": algorithms.Blowfish,


In [None]:
#import karstnet as kn
# import networkx as nx
import matplotlib.pyplot as plt
import numpy as np
import sqlite3
from collections import defaultdict
import networkx as nx
from scipy import spatial
import pandas as pd
import pyvista as pv
pv.set_jupyter_backend('trame')


def get_pos2d(G):
    return {key: value[0:2] for key, value in nx.get_node_attributes(G,'coord').items()}

def get_pos3d(G):
    return nx.get_node_attributes(G,'coord')

def get_nodes_attributes(G):
    # get list of node keys
    return set([k for n in G.nodes for k in G.nodes[n].keys()])
    

def calc_projected_splay(conduit_direction,splay_direction):
    """Projects a splay (disto measurement from a station to the surrounding wall) 
    in the u,v plan perpendicular to the cave conduit direction.


    Args:
        conduit_direction (numpy.ndarray, list): list or array of [x,y,z] values 
            of the direction vector of the conduit at the point of interest.
            vector origins has to be (0,0,0).
            use function ... to determin the right direction of the vector.
        splay_direction (numpy.ndarray, list): list or array of [x,y,z] values 
            of the direction of splay. Splay values - node.
            vector origins has to be (0,0,0).

    Returns:
        numpy.ndarray: projected splay length in the u and v direction. 
            (splay vector projectedin the uv plan)
            array of two values [u,v]. 
            u is in horizontal x,y plane, and v is orthogonal to u in the uv plan.
    """
    #normalize vector w
    w = np.array(conduit_direction)
    w /= np.linalg.norm(w)
    #find normalized vector u in the horizontal plan, orthogonal to vector w
    if ((np.abs(w[0])>np.abs(w[1])) and (np.abs(w[0])!=0)):
        #checks that x is not equal to zero and that its longer than y
        u = np.array([-w[1]/w[0],1,0])
    elif ((np.abs(w[1])>np.abs(w[0])) and (np.abs(w[1])!=0)):
        u = np.array([1,-w[0]/w[1],0])
    else:
        print('%s and %s are equal to zero. Normalize vector cannot be calculated'%(w[0],w[1]))
        
    u /= np.linalg.norm(u)
    #find normalized vector v, orthogonal to u, in the uv plan
    v = np.cross(u,w)
    #calculate the length of the projected vector w in the u and v direction
    s_u = np.dot(u,splay_direction)
    s_v = np.dot(v,splay_direction)
    return np.array([s_u,s_v])





def get_conduit_direction(kg):
    keys = list(dict(kg.graph.degree()).keys())
    conduit_directions = {}

    # get conduit direction

    for i,key in enumerate(keys):
        degree = kg.graph.degree()[key]
        
        # calc w at the beginning  and end of a conduit
        if degree == 1:
            #if node is the start of the cave
            if keys[0]==key:
                #if the node $P_{i}$ is degree 1 and a starting node, $w=P_{i+1}-P_{i}$
                conduit_directions[key] = np.array(kg.pos3d[keys[i+1]])-np.array(kg.pos3d[key])

            #if node is the end of a conduit
            else:
                #if the node $P_{i}$ is degree 1 and a ending node, $w=P_{i}-P_{i-1}$
                conduit_directions[key] = np.array(kg.pos3d[key])-np.array(kg.pos3d[keys[i-1]])

        #calc w in the middle of a conduit
        if degree == 2:
            #if the node $P_{i}$ is degree 2, $w=P_{i+1}-P_{i-1}$ 
            conduit_directions[key] = np.array(kg.pos3d[keys[i+1]])-np.array(kg.pos3d[keys[i-1]])

        if degree > 2:
            # no projection for conduit intersections
            conduit_directions[key] = []
    
    return conduit_directions


def get_splay_direction(kg,splays):
    splays_direction = {} 
    keys = list(splays.keys())
    for key in keys:
    #calculate the splay direction relative to its node
        #if the node exists. some nodes were survey duplicate, removed from the main graph
        if key in kg.pos3d.keys():
            list_temp = []
            for j in np.arange(len(splays[key])):
                list_temp.append(np.array(splays[key][j])-np.array(kg.pos3d[key]))
            splays_direction[key] = list_temp

    return splays_direction

def calc_conduit_dimensions(conduit_directions,splays_direction):

    projected_splays_u = {}
    projected_splays_v = {}

    width = {}
    height = {}

    keys = list(splays_direction.keys())
    for key in keys:
    #calculate the projected splay length component in u and v direction
        #sometimes there is splays that were duplicates. 
        if key in conduit_directions:
            su=[]
            sv=[]
            if len(conduit_directions[key])>0:
                for splay_direction in splays_direction[key]:       
                    s = calc_projected_splay(conduit_directions[key],splay_direction)
                    su.append(s[0])
                    sv.append(s[1])
                # # store su and vs in a dictionnary
                # # this is only interesting for plots...
                projected_splays_u[key] = su
                projected_splays_v[key] = sv   

                # add zero to the array, in case all splays have been taken in the same direction
                su.append(0)
                sv.append(0)

                # extract the futherst coordinate on each side
                width[key]  = max(su)-min(su)
                height[key]  = max(sv)-min(sv)
        
            else:
                #for intersections, calculate the mean of the splays, instead of the projections
                for splay_direction in splays_direction[key]: 
                    su.append(np.linalg.norm(splay_direction))
                    sv.append(np.linalg.norm(splay_direction))
                width[key] = np.round(np.mean(su),2)
                height[key] = np.round(np.mean(sv),2)

    return projected_splays_u, projected_splays_v, width, height

def read_sql_file(basename):
    '''
    

    Parameters
    ----------
    basename : str
        name of the sql database. without the extension.

    Returns
    -------
    c : TYPE
        DESCRIPTION.

    '''
    sql_name = basename + '.sql'

    try:
        conn = sqlite3.connect(':memory:')
        conn.executescript(open(sql_name).read())
    #    	conn.executescript(open('../data/g_huttes.sql').read())
    except OSError:
        print("IMPORT ERROR: Could not import {}".format(sql_name))
    #    return

    # Read the SQL file
    c = conn.cursor()
    return c


# load flags
############


def extract_flags(c,string):

    flag_list=[]
    c.execute(string)
    for s in c.fetchall():
        flag_list.append([s[1],s[0]])

    flag_dict = defaultdict(list)
    for key, *values in flag_list:
        flag_dict[key].extend(values) 

    return dict(flag_dict)

def load_flag_station(basename):
    # Station Flags
    # 'ent' = entrance, 'con' = continuation, 'fix' = fixed, 
    # 'spr' = spring, 'sin' = sink, 'dol' = doline, 'dig' = dig, 
    # 'air' =air-draught, 'ove' = overhang, 'arc' = arch attributes
    c=read_sql_file(basename)
    return extract_flags(c,'select STATION_ID, FLAG from STATION_FLAG')

def load_station_name(basename):
    c = read_sql_file(basename)
    station_name = {}
    c.execute('select st.ID, st.NAME from STATION \
                st left join SURVEY su on st.SURVEY_ID = su.ID \
                where st.NAME not in (".","-")')
    for s in c.fetchall():
        station_name[s[0]] = s[1]
    return station_name

def load_flag_shot(basename):
    # Shot Flags
    # 'dpl' = duplicate, 'srf' = surface shots
    c=read_sql_file(basename)
    return extract_flags(c,'select SHOT_ID, FLAG from SHOT_FLAG')


def load_splays(basename):
    c=read_sql_file(basename)
    nodes_coord = []
    stations_id = []
    c.execute('select st.ID, st.NAME, FULL_NAME, X, Y, Z from STATION \
                st left join SURVEY su on st.SURVEY_ID = su.ID \
                where st.NAME in (".","-")')

    for s in c.fetchall():
        #extract x,y,z nodes coordinates
        nodes_coord.append([s[3], s[4], s[5]])
        #extract unique node id from Therion
        stations_id.append(s[0])
    #create dictionnary of the nodes coordinates
    coord = dict(zip(stations_id,nodes_coord))

    #import links only for the nodes we exported
    string_id = ",".join(map(str,stations_id))
    c.execute('select FROM_ID, TO_ID from SHOT \
            where TO_ID in (%s)' % (string_id))

    links = []
    for l in c.fetchall():
        links.append([l[0], l[1]])
    #links = np.asarray(links).astype(int)

    links_dict = defaultdict(list)
    for key, value in links:   
        links_dict[key].extend([coord[value]]) 

    return dict(links_dict)


def nextstep(G, path):
     """
     from karstnet

     Work on self_graph
     Adds the next node to a path of self_graph along a branch.
     Stops when reaches a node of degree different from 2.

     Parameters
     ----------
         path : list
             A list of nodes to explain a path

     Returns:
     --------
     list
         path : A list of nodes to explain a path

     bool
         stopc
     """
     current = path[-1]
     # Checks first if the end of the path is already on an end
     if G.degree(current) != 2:
         stopc = False
         return path, stopc

     # This is a security / it may be removed
     if len(path) > 1:
         old = path[-2]
     else:
         old = current

     # Among the neighbors search for the next one
     for nextn in G.neighbors(current):
         if old != nextn:
             break

     # Add the next node to the path and check stopping criteria
     # noinspection PyUnboundLocalVariable
     path.append(nextn)

     # Test for a closed loop / even if start node has degree = 2
     test_loop = path[0] == path[-1]

     if (G.degree(nextn) != 2) or test_loop:
         stopc = False
     else:
         stopc = True

     return path, stopc



def getallbranches(G):
    """
    NOT PUBLIC
    
    Constructs the list of all branches of the karstic graph self_graph.
    Compute lengths and tortuosities
    """
    # Initialisations
    target = []
    degree_target = []
    
    # Create one subgraph per connected components
    # to get isolated loops as branches
    # Return a list of connected graphs
    list_sub_gr = [G.subgraph(c).copy()
                   for c in nx.connected_components(G)]
    
    for sub_gr in list_sub_gr:
        local_counter = 0
        last_node_index = 0
        # Identifies all the extremeties of the branches (nodes of degree
        # != 2)
        for i in sub_gr.nodes():
            if (sub_gr.degree(i) != 2):
                target.append(i)
                degree_target.append(nx.degree(sub_gr, i))
                local_counter += 1
            last_node_index = i
        # to manage cases where a subgraph is only composed of nodes of
        # degree 2
        if (local_counter == 0):
            target.append(last_node_index)
            degree_target.append(nx.degree(sub_gr, last_node_index))
    
    # Identifies all the neighbors of those nodes,
    # to create all the initial paths
    list_start_branches = []
    for i in target:
        for n in G.neighbors(i):
            list_start_branches.append([i, n])
    
    # Follow all these initial paths to get all the branches
    branches = []
    for path in list_start_branches:
        go = True
        # Check all existing branches to avoid adding a branch twice
        # if starting from other extremity
        for knownbranch in branches:
            if ((path[0] == knownbranch[-1]) &
                    (path[1] == knownbranch[-2])):
                go = False
                break
        if go:
            #get a list for a single branch
            path, stopc = nextstep(G,path)
            while stopc:
                path, stopc = nextstep(G,path)
        
            branches.append(path)
        
    return branches


######################################################################
# IMPORT DATA
#####################################################################

def load_raw_therion_data(basename):
    '''
    #potential node flags !!! todo: create option to extract data from those flags
    # 'ent' = entrance, 'con' = continuation, 'fix' = fixed, 
    # 'spr' = spring, 'sin' = sink, 'dol' = doline, 'dig' = dig, 
    # 'air' =air-draught, 'ove' = overhang, 'arc' = arch attributes  
    #potential shot flags:
        but they should not be loaded for the graph.
            # this is the list of graphs 
            # 'dpl' = duplicate, 'srf' = surface shots 
    
    Parameters
    ----------
    basename : TYPE
        DESCRIPTION.
    verbose : TYPE, optional
        DESCRIPTION. The default is True.
    remove_flagged_shots : TYPE, optional
        DESCRIPTION. The default is True.
    
    Returns
    -------
    G : TYPE
        DESCRIPTION.
    
    '''
    #read the sql database
    c = read_sql_file(basename)
    
    # import all LINKS 
    ###############
    c.execute('select FROM_ID, TO_ID from SHOT')
    links_all = []
    for l in c.fetchall():
        links_all.append(l)
    
    # import NODES
    ###############################################################
    #import all nodes
     
    #prevents extraction of anonymous survey point symbol (- or .)     
    c.execute('select st.ID, st.NAME, FULL_NAME, X, Y, Z from STATION st \
              left join SURVEY su on st.SURVEY_ID = su.ID' 
                )           
    nodes_coord = []
    stations_id = []
    for s in c.fetchall():
        #extract x,y,z nodes coordinates
        nodes_coord.append([s[3], s[4], s[5]])
        #extract unique node id from Therion
        stations_id.append(s[0])
    #create dictionnary of the nodes coordinates
    coord = dict(zip(stations_id,nodes_coord))
    
    #create graph with all the links
    #################################
    G = nx.Graph()
    G.add_edges_from(links_all)
    nx.set_node_attributes(G, coord, 'coord')
    
    #splay legs
    ###################################
    #remove nodes that are anonymous survey point symbol (- or .)
    
    splay_id = []    
    c.execute('select st.ID, st.NAME, FULL_NAME, X, Y, Z from STATION st \
            left join SURVEY su on st.SURVEY_ID = su.ID \
            where st.NAME in (".","-")' )
    for s in c.fetchall():
        splay_id.append(s[0])
        
    if splay_id:
        G.remove_nodes_from(splay_id)
    else: 
        print('no splays legs to remove')
    #G.remove_nodes_from(list(splays_dict.keys()))
    
    
    #add splay leg shot info on nodes. in the form of a list of coordinates of the end of the shot.
    #load splay legs
    splays_dict = load_splays(basename)
    splays_dict = load_splays(basename)
    nx.set_node_attributes(G, splays_dict, 'splays')   
    
    #add potential node flags
    station_flags_dict = load_flag_station(basename)
        # Station Flags
        # 'ent' = entrance, 'con' = continuation, 'fix' = fixed, 
        # 'spr' = spring, 'sin' = sink, 'dol' = doline, 'dig' = dig, 
        # 'air' =air-draught, 'ove' = overhang, 'arc' = arch attributes
    nx.set_node_attributes(G, station_flags_dict, 'flags')     
    
    #Rename nodes and get ride of duplicate nodes with identical position
    ##############################################################################
    ##############################################################################
    # this rename nodes with identical position with the same id, 
    # which automatically regroup the nodes with identical name into one.
     
    #pos2d = {key: value[0:2] for key, value in nx.get_node_attributes(G,'coord').items()}
    # plt.figure()
    # nx.draw(G,pos=pos2d)
    
    #find nodes with duplicate positions:
    #create a list of lists of index where the coordinates are the same
    unique_pos = [list(x) for x in set(tuple(x) for x in list(nx.get_node_attributes(G,'coord').values()))]
    duplicates = []
    for position in unique_pos:
        duplicates.append([key for key,coord in G.nodes('coord') if coord==position])
        
    #rename nodes 
    ########################################################################
    #duplicate nodes are renamed with the same name
    
    #create new ids dictionnary to replace the initial indexes
    
    #create new ids with repeating values for idential node position 
    newis = [] 
    for i, index in enumerate(np.arange(len(duplicates))):
        newis = newis + [index]*len(duplicates[i])
    
    #flatten the list of list of old ids
    concat_oldi = [j for i in duplicates for j in i]   
    
    #the dictionnary has to be in the form of dict keys are the old keys, and the value is the new key
    index_dict = dict(zip(concat_oldi, newis ))
    
    #rename nodes (nodes with same geographic posiion will be "merged" under the same name)
    G = nx.relabel_nodes(G,index_dict)
    #drop edges that link the node to themselves. happen because of the combining the nodes.
    G.remove_edges_from(list(nx.selfloop_edges(G)))
    
    #add old therion id name as a property
    therion_ids = {}
    for k, v in zip(newis, concat_oldi):
        therion_ids.setdefault(k, []).append(v)
    nx.set_node_attributes(G, therion_ids, 'therion_id')
 
    
    #remove nodes that were isolated when removing the edges
    G.remove_nodes_from(list(nx.isolates(G)))   
    

    # #plot to see the result
    # pos2d = {key: value[0:2] for key, value in nx.get_node_attributes(G,'coord').items()}
    # plt.figure()
    # nx.draw(G,pos=pos2d)
 
    return G

def load_flagged_links(basename):
    #clean graph
    #########################################################
    ############################################################
    
    #read the sql database
    c = read_sql_file(basename)
    # load flags
    ############
    # flags will be used to ignore duplicate and surface points
    # duplicates are integreted in therion coordinates calculations,

    flag_id = []   
    c.execute('select SHOT_ID, FLAG from SHOT_FLAG where FLAG="srf" or FLAG="dpl"')
    for s in c.fetchall():
        flag_id.append(s[0]) 
        
    #import links only for the nodes we exported
    string_id = ",".join(map(str,flag_id))
    c.execute('select sh.ID, FROM_ID, TO_ID from SHOT sh\
                where sh.ID in (%s)' % (string_id))
            
    links_flagged = []
    for l in c.fetchall():
        links_flagged.append((l[1],l[2]))
        
    if links_flagged:       
        #G.remove_edges_from(links_flagged)
        
        return links_flagged

    else:
        return []
        print('no duplicate or surface flags to remove')


def find_flagged_links(basename):
    G_raw = load_raw_therion_data(basename)

    flagged_shots = load_flagged_links(basename)
    #G_raw = G.remove_edges_from(links_flagged)

    new_flag_list = []
    for old_keys in flagged_shots:
        #print(link)
        new_key_1 = [key for key, value in dict(G_raw.nodes('therion_id')).items() if old_keys[0] in value]
        new_key_2 = [key for key, value in dict(G_raw.nodes('therion_id')).items() if old_keys[1] in value]
        if new_key_1 and new_key_2:
            new_flag_list.append(new_key_1+new_key_2)


    #remove_edges_from() fails if one of the tuple to remove is not in G. This remove inexistant nodes  
    sorted_flags = [sorted(tup) for tup in new_flag_list]
    sorted_edges = [sorted(tup) for tup in G_raw.edges()]
    new_flag_list_clean = [i for i in sorted_flags if i in sorted_edges]

    return new_flag_list_clean


############################################################################
# RECONNECT DISCONNECTED COMPONENT
############################################################################
#If node is degree 1 in clean database, and degree 2 in raw database, 
# then connect with closest node in clean database.

def load_therion_without_flagged_edges(basename):
    #remove flagged edges
    G = load_raw_therion_data(basename)
    flagged_edges = find_flagged_links(basename) 
    G.remove_edges_from(flagged_edges) 
    #remove nodes that were isolated when removing the edges
    G.remove_nodes_from(list(nx.isolates(G))) 
    return G

def reconnected_components(basename):
    G_raw = load_raw_therion_data(basename)
    
    G = load_therion_without_flagged_edges(basename)

    print( 'There is ', nx.number_connected_components(G), 'disconnected components')
    
    closeby_all = []
    keys_disconnected_all =[]
    
    #cc_number = 0
    if nx.is_connected(G) == False:
        #iterate through the connectec components to find nodes where disconnection occured
        #search for the nodes that used to be degree >1 and are now degree 1.
        for i, subgraph_index in enumerate(nx.connected_components(G)):
            #print(i,subgraph_index )
            subgraph = nx.subgraph(G, subgraph_index)

            keys_disconnected_subgraph=[]   
            #find all the nodes where disconnection happened
            #look for all the nodes degree smaller in the cleaned file than in the original file 
            #keys_disconnected_subgraph = [k for k, v in dict(subgraph.degree()).items() if v == 1 and G_raw.degree()[k] >1]   
            for k in subgraph.nodes(): #dict(subgraph.degree()).items():
                #print(k)
                if subgraph.degree()[k]==1 and G_raw.degree()[k] >1:
                    keys_disconnected_subgraph.append(k)
                    #print(v,'>',G_raw.degree()[k])
                
            keys_disconnected_all = keys_disconnected_all + keys_disconnected_subgraph
                
            #find closeby point not in component£
            subpos3d = {key: value for key, value in nx.get_node_attributes(G,'coord').items() if key not in list(subgraph.nodes()) }
            #print(len(subpos3d))
            tree = spatial.KDTree(list(subpos3d.values())) 
            list_closeby = []
            for key in keys_disconnected_subgraph:          
                dist, ind = tree.query(nx.get_node_attributes(G,'coord')[key], 2, distance_upper_bound=30 )
                #select the second one
                #if no points are found inside the upper bound distance set, 
                #it saves the largest index as a flag. so ignore point those flags
                if ind[1] != len(subpos3d):            
                    #print(key,ind, dist)
                    #print(int(ind[1]))
                    key_close = list(subpos3d)[ind[1]]
                
                    print(key, key_close, dist[1])
                    # if dist<70:
                    #     list_closeby.append([keys[0], keys[1], dist])        
                    list_closeby.append([key, key_close, dist[1]])
            
            closeby_all = closeby_all + list_closeby   
        
        coordinates_disconnected_all = [nx.get_node_attributes(G,'coord')[key] for key in keys_disconnected_all]
        pd.DataFrame({'keys':keys_disconnected_all, 
                                'x': [item[0] for item in coordinates_disconnected_all],
                                'y': [item[1] for item in coordinates_disconnected_all],
                                'z': [item[2] for item in coordinates_disconnected_all] }).to_csv('list_disco_nodes.csv', index=False)
        
        pd.DataFrame({'key_1': [item[0] for item in closeby_all],
                      'key_2': [item[1] for item in closeby_all],
                      'key_3': [item[2] for item in closeby_all]}).to_csv('list_reco_nodes.csv', index=False)
        
          
        
        #G.add_edges_from([[item[0],item[1]] for item in closeby_all])  
        keys = [(item[0],item[1]) for item in closeby_all]
        
        #nx.set_edge_attributes( G, {(item[0],item[1]):{'flag':'infered'} for item in closeby_all} )  
    else:
        print('There is no disconnected components, no need to merge')
        
    return keys , keys_disconnected_all


#%%

#basename = 'data/data_GouffreDejaVu/GouffreDejaVu'
basename = 'data/data_CuevaPirates/CuevaPirates'
#basename = 'data/data_AvenDesArchesPerdues/AvenDesArchesPerdues'
#basename = 'data/data_migovec/system_migovec'
#basename = 'data/data_CheddarCatchment/Charterhouse/Charterhouse'
#######################################################################################

G_raw = load_raw_therion_data(basename)

G = load_therion_without_flagged_edges(basename)

keys_manual = [(96,31),(501,442),(300,612)]

#%%
pl = pv.Plotter()


lines = np.array([G_raw.nodes('coord')[i] for i in np.hstack(G_raw.edges())])
pl.add_lines(lines, color='grey', width=4,connected=False)

import random

points = np.array([value for value in iter(dict(G.nodes('coord')).values())])
points_raw = np.array([value for value in iter(dict(G_raw.nodes('coord')).values())])

pl.add_points(points, render_points_as_spheres=True, point_size=7, color='black')
#name the points based on the real index

if nx.number_connected_components(G)>0:
    #colors = ["#"+''.join([random.choice('0123456789ABCDEF') for j in range(6)]) for i in range(nx.number_connected_components(G))]
    colors = [ "mediumblue" , "darkturquoise", "steelblue","dodgerblue","deepskyblue","aquamarine","springgreen","seagreen","orange",'gold',"olivedrab","darkgoldenrod","tan",
              "mediumblue" , "darkturquoise", "steelblue","dodgerblue","deepskyblue","aquamarine","springgreen","seagreen","orange",'gold',"olivedrab","darkgoldenrod","tan",
              "mediumblue" , "darkturquoise", "steelblue","dodgerblue","deepskyblue","aquamarine","springgreen","seagreen","orange",'gold',"olivedrab","darkgoldenrod","tan",
              "mediumblue" , "darkturquoise", "steelblue","dodgerblue","deepskyblue","aquamarine","springgreen","seagreen","orange",'gold',"olivedrab","darkgoldenrod","tan",
              "mediumblue" , "darkturquoise", "steelblue","dodgerblue","deepskyblue","aquamarine","springgreen","seagreen","orange",'gold',"olivedrab","darkgoldenrod","tan",]
    
    for i, subgraph_index in enumerate(nx.connected_components(G)):
        subgraph = nx.subgraph(G, subgraph_index)
        lines = np.array([subgraph.nodes('coord')[i] for i in np.hstack(subgraph.edges())])
        pl.add_lines(lines, color=colors[i], width=7, connected=False)
       
    reco_keys, disco_keys = reconnected_components(basename)

    lines_reco = np.array([G.nodes('coord')[i] for i in np.hstack(reco_keys)])    
    pl.add_lines(lines_reco, color='purple', width=4, connected=False)
    
    lines_reco_manual = np.array([G.nodes('coord')[i] for i in np.hstack(keys_manual)])    
    pl.add_lines(lines_reco_manual, color='red', width=6, connected=False)
    
    points_disco = np.array([G.nodes('coord')[i] for i in disco_keys])
    pl.add_points(points_disco, render_points_as_spheres=True, point_size=15, color='red')

else:
    pl.add_lines(lines, color='grey', width=15, connected=False)
    
    
pl.add_point_labels(
        points,
        G.nodes(),
        always_visible=True,
        fill_shape=False,
        margin=100,
        shape_opacity=0.0,
        font_size=20)

pl.add_point_labels(
        points_raw,
        G_raw.nodes(),
        always_visible=True,
        fill_shape=False,
        margin=100,
        shape_opacity=0.0,
        font_size=20)
    

#pl.camera_position = 'xy'
pl.show()

#pl.export_html('test.html')

#plot3dgraph(G)


#%%



#%% loop through branches

# #branches = getallbranches(G)
# vertices = list(dict(G.nodes('coord')).values())

# pl = pv.Plotter()
# #points = np.array([[0, 1, 0], [1, 0, 0], [1, 1, 0], [2, 0, 0]])
# #for branch in branches:
# points = np.array([vertices[i] for i in np.hstack(G.edges())])
# pl.add_lines(points, color='purple', width=7, connected=False)
# pl.add_points(points, render_points_as_spheres=True, point_size=15, color='black')

# pl.add_point_labels(
#         points,
#         np.hstack(G.edges()),
#         always_visible=True,
#         fill_shape=False,
#         margin=100,
#         shape_opacity=0.0,
#         font_size=20)
    

# pl.camera_position = 'xy'
# pl.show()

#%%


# branches = getallbranches(G)
# vertices = list(dict(G.nodes('coord')).values())

# import numpy as np
# import pyvista as pv
# pl = pv.Plotter()
# #points = np.array([[0, 1, 0], [1, 0, 0], [1, 1, 0], [2, 0, 0]])
# for branch in branches:
#     points = np.array([vertices[i] for i in branch])
#     pl.add_lines(points, color='purple', width=7, connected=True)
#     pl.add_points(points, render_points_as_spheres=True, point_size=15, color='black')
    
#     pl.add_point_labels(
#         points,
#         branch,
#         always_visible=True,
#         fill_shape=False,
#         margin=100,
#         shape_opacity=0.0,
#         font_size=20)
    

# pl.camera_position = 'xy'
# pl.show()
# #%%

# import pyvista as pv

# mesh = pv.MultipleLines(points=[[0, 0, 0], [1, 1, 1], [0, 0, 1],  [1, 0, 1]])

# plotter = pv.Plotter()

# actor = plotter.add_mesh(mesh, color='k', line_width=10)

# plotter.camera.azimuth = 45

# plotter.camera.zoom(0.8)

# plotter.show()


# #%%

# import numpy as np
# import pyvista
# vertices = np.array([[0, 0, 0], [1, 0, 0], [1, 0.5, 0], [0, 0.5, 0]])
# lines = np.hstack([[3, 0, 1,2],[3, 2, 1,3], [2,3, 2]])
# mesh = pyvista.PolyData(vertices, lines=lines)

# plotter = pyvista.Plotter()
# actor = plotter.add_mesh(mesh, color='k', line_width=10)
# plotter.camera.azimuth = 45
# plotter.camera.zoom(0.8)
# plotter.show()


#%% VTK export

#https://stackoverflow.com/questions/62888678/saving-a-3d-graph-generated-in-networkx-to-vtk-format-for-viewing-in-paraview

# pos = [[0.1, 2, 0.3],    [40, 0.5, -10],
#         [0.1, -40, 0.3],  [-49, 0.1, 2],
#         [10.3, 0.3, 0.4], [-109, 0.3, 0.4]]

# ed_ls = [(x, y) for x, y in zip(range(0, 5), range(1, 6))]

# pos = get_pos3d(G).values()
# ed_ls = list(G.edges())
# nxpos = nx.spring_layout(G)
# nxpts = [nxpos[pt] for pt in sorted(nxpos)]
# nx.draw(G, with_labels=True, pos=nxpos)
# plt.show()

#raw_lines = [(pos[x],pos[y]) for x, y in ed_ls]
# nx_lines = []
# for x, y in ed_ls:
#     p1 = nxpos[x].tolist() + [0] # add z-coord
#     p2 = nxpos[y].tolist() + [0]
#     nx_lines.append([p1,p2])

# import vedo as vo
# raw_pts = vo.Points(pos, r=12)
# raw_edg = vo.Lines(raw_lines).lw(2)
# vo.show(raw_pts, raw_edg, raw_pts.labels('id'),
#       at=0, N=2, axes=True, sharecam=False)

# nx_pts = vo.Points(nxpts, r=12)
# nx_edg = vo.Lines(raw_lines).lw(2)
# vo.show(nx_pts, nx_edg, nx_pts.labels('id'),
#       at=1, interactive=True)

# vo.write(nx_edg, 'afile.vtk') # save the lines



#%%

# from pyvista import examples
# grid = examples.cells.PolyLine()
# examples.plot_cell(grid)


# #%%
# import numpy as np
# import pyvista
# import networkx as nx


# pos3d = get_pos3d(G)



# cells = [4, 0, 1]
# celltypes = [pyvista.CellType.POLY_LINE]
# points =  [[1.0, 1.0, 1.0],  [1.0, -1.0, -1.0],
#           [-1.0, 1.0, -1.0], [-1.0, -1.0, 1.0],]
# grid = pyvista.UnstructuredGrid(cells, celltypes, points)
# grid.plot(show_edges=True)

#%%
# #point_cloud = np.array([pos3d[v] for v in np.arange(len(G))])
# point_cloud = np.array([pos3d[v] for v in G.nodes()])
# #edge_xyz = np.array([(pos[u], pos[v]) for u, v in G.edges()])

# #point_cloud = np.random.random((100, 3))
# pdata = pyvista.PolyData(point_cloud)
# #pdata['orig_sphere'] = np.arange(len(G))

# # create many spheres from the point cloud
# sphere = pyvista.Sphere(radius=0.5, phi_resolution=10, theta_resolution=10)
# pc = pdata.glyph(scale=False, geom=sphere, orient=False)
# pc.plot(cmap='Reds')

# Dataset used for arbitrary combinations of all possible cell types.

# Can be initialized by the following:

# - Creating an empty grid
# - From a ``vtk.vtkPolyData`` or ``vtk.vtkStructuredGrid`` object
# - From cell, cell types, and point arrays
# - From a file

# Parameters
# ----------
# args : str, vtk.vtkUnstructuredGrid, iterable
#     See examples below.
# deep : bool, default: False
#     Whether to deep copy a vtkUnstructuredGrid object.
#     Default is ``False``.  Keyword only.


#%%


# import pyvis.network as pn


# points = np.column_stack((x, y, z))
# spline = pv.Spline(points, 1000)

# spline.plot(

#     render_lines_as_tubes=True,

#     line_width=10,

#     show_scalar_bar=False,

# )

# g = pn.Network()
# g.toggle_hide_edges_on_drag(False)
# g.barnes_hut()
# g.from_nx(G)
# g.show("ex.html")

# #%% basic matplotlib

# # import networkx as nx
# # import numpy as np
# # import matplotlib.pyplot as plt
# # from mpl_toolkits.mplot3d import Axes3D

# # # The graph to visualize
# # G = nx.cycle_graph(20)

# # # 3d spring layout
# # pos = nx.spring_layout(G, dim=3, seed=779)
# # Extract node and edge positions from the layout

# node_xyz = np.array([pos[v] for v in sorted(G)])
# edge_xyz = np.array([(pos[u], pos[v]) for u, v in G.edges()])

# # Create the 3D figure
# fig = plt.figure()
# ax = fig.add_subplot(111, projection="3d")

# # Plot the nodes - alpha is scaled by "depth" automatically
# ax.scatter(*node_xyz.T, s=100, ec="w")

# # Plot the edges
# for vizedge in edge_xyz:
#     ax.plot(*vizedge.T, color="tab:gray")


# def _format_axes(ax):
#     """Visualization options for the 3D axes."""
#     # Turn gridlines off
#     ax.grid(False)
#     # Suppress tick labels
#     for dim in (ax.xaxis, ax.yaxis, ax.zaxis):
#         dim.set_ticks([])
#     # Set axes labels
#     ax.set_xlabel("x")
#     ax.set_ylabel("y")
#     ax.set_zlabel("z")


# _format_axes(ax)
# fig.tight_layout()
# plt.show()




#%%###########################################################################
# calculate graph geometry and attach it as a attribute to the graph
############################################################################




































There is  2 disconnected components
501 495 22.94095682398622
96 64 9.137018113148299


Widget(value="<iframe src='http://localhost:55196/index.html?ui=P_0x290eda38550_1&reconnect=auto' style='width…

Exception raised
ConnectionResetError('Cannot write to closing transport')
Traceback (most recent call last):
  File "c:\Users\celia\anaconda3\envs\ludo\Lib\site-packages\wslink\protocol.py", line 340, in onMessage
    await self.sendWrappedMessage(
  File "c:\Users\celia\anaconda3\envs\ludo\Lib\site-packages\wslink\protocol.py", line 484, in sendWrappedMessage
    await ws.send_str(encMsg)
  File "c:\Users\celia\anaconda3\envs\ludo\Lib\site-packages\aiohttp\web_ws.py", line 336, in send_str
    await self._writer.send(data, binary=False, compress=compress)
  File "c:\Users\celia\anaconda3\envs\ludo\Lib\site-packages\aiohttp\http_websocket.py", line 723, in send
    await self._send_frame(message, WSMsgType.TEXT, compress)
  File "c:\Users\celia\anaconda3\envs\ludo\Lib\site-packages\aiohttp\http_websocket.py", line 686, in _send_frame
    self._write(header + message)
  File "c:\Users\celia\anaconda3\envs\ludo\Lib\site-packages\aiohttp\http_websocket.py", line 696, in _write
    raise 